Add files using upload-large-folder tool
Browse files- LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507.pid +1 -0
- LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta.pid +1 -0
- LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta_rerun.log +61 -0
- LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/sample_fully_coupled.py +146 -0
- LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/train_char.py +618 -0
- LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513.log +94 -0
- LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513_trace.log +1 -0
- LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0010000_t1p45.log +68 -0
- LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0020000_t1p45.log +68 -0
- LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0030000_t1p45.log +68 -0
- LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/processed_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_steps128_c1024_t1p45_n1024.txt +4 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.bat +71 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.fish +124 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/f2py +10 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/pydoc.bat +22 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/audio_utils.py +1254 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/__init__.py +13 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/add_new_model_like.py +790 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/chat.py +673 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/download.py +40 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/serve.py +241 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/system.py +139 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/transformers.py +41 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/__init__.py +33 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/configuration_utils.py +110 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/hyperparameter_search.py +123 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_transforms.py +1073 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/__init__.py +30 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/configuration_gemma3.py +225 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_gemma3.py +250 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_pil_gemma3.py +225 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py +1118 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modular_gemma3.py +941 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/processing_gemma3.py +165 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/__init__.py +27 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/configuration_youtu.py +107 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modeling_youtu.py +607 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modular_youtu.py +151 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/testing_utils.py +0 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/training_args.py +0 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_070000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_078000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_079000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_343000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_352000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_390000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_433000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_471000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_565000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_571000.pt +3 -0
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507.pid
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
354158
|
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta.pid
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
354493
|
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta_rerun.log
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
*****************************************
|
| 3 |
+
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
*****************************************
|
| 5 |
+
[setup] device=cuda:0 rank=0 world_size=4
|
| 6 |
+
[rank2]:[W507 18:31:20.830802063 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 7 |
+
[rank0]:[W507 18:31:20.831852183 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 8 |
+
[rank3]:[W507 18:31:20.833351576 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 9 |
+
[rank1]:[W507 18:31:20.834403717 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 10 |
+
[data] chars=1115394 vocab=65 train=1003854 val=111540
|
| 11 |
+
[lta] step=1 loss=4.2253 elapsed=0.6s
|
| 12 |
+
[lta] step=100 loss=1.8800 elapsed=1.6s
|
| 13 |
+
[lta] step=200 loss=1.7154 elapsed=2.6s
|
| 14 |
+
[lta] step=300 loss=1.8433 elapsed=3.7s
|
| 15 |
+
[lta] step=400 loss=1.5804 elapsed=4.7s
|
| 16 |
+
[lta] step=500 loss=1.3491 elapsed=5.7s
|
| 17 |
+
[lta] step=600 loss=1.5344 elapsed=7.2s
|
| 18 |
+
[lta] step=700 loss=1.2540 elapsed=8.3s
|
| 19 |
+
[lta] step=800 loss=1.3757 elapsed=9.3s
|
| 20 |
+
[lta] step=900 loss=1.4476 elapsed=10.3s
|
| 21 |
+
[lta] step=1000 loss=1.4170 elapsed=11.3s
|
| 22 |
+
[lta] step=1100 loss=1.5610 elapsed=12.8s
|
| 23 |
+
[lta] step=1200 loss=1.4933 elapsed=13.8s
|
| 24 |
+
[lta] step=1300 loss=1.5656 elapsed=14.9s
|
| 25 |
+
[lta] step=1400 loss=1.5198 elapsed=15.9s
|
| 26 |
+
[lta] step=1500 loss=1.4798 elapsed=17.0s
|
| 27 |
+
[lta] step=1600 loss=1.5783 elapsed=18.5s
|
| 28 |
+
[lta] step=1700 loss=1.1984 elapsed=19.5s
|
| 29 |
+
[lta] step=1800 loss=1.2941 elapsed=20.5s
|
| 30 |
+
[lta] step=1900 loss=1.5220 elapsed=21.5s
|
| 31 |
+
[lta] step=2000 loss=1.2615 elapsed=22.6s
|
| 32 |
+
[lta] step=2100 loss=1.3370 elapsed=24.1s
|
| 33 |
+
[lta] step=2200 loss=1.1854 elapsed=25.1s
|
| 34 |
+
[lta] step=2300 loss=0.9726 elapsed=26.1s
|
| 35 |
+
[lta] step=2400 loss=1.4613 elapsed=27.1s
|
| 36 |
+
[lta] step=2500 loss=1.3016 elapsed=28.2s
|
| 37 |
+
[lta] step=2600 loss=1.3408 elapsed=29.7s
|
| 38 |
+
[lta] step=2700 loss=1.3022 elapsed=30.7s
|
| 39 |
+
[lta] step=2800 loss=1.4492 elapsed=31.7s
|
| 40 |
+
[lta] step=2900 loss=1.1530 elapsed=32.7s
|
| 41 |
+
[lta] step=3000 loss=1.4642 elapsed=33.8s
|
| 42 |
+
[lta] step=3100 loss=1.2645 elapsed=35.3s
|
| 43 |
+
[lta] step=3200 loss=1.4777 elapsed=36.3s
|
| 44 |
+
[lta] step=3300 loss=1.0923 elapsed=37.4s
|
| 45 |
+
[lta] step=3400 loss=1.1992 elapsed=38.5s
|
| 46 |
+
[lta] step=3500 loss=1.4760 elapsed=39.6s
|
| 47 |
+
[lta] step=3600 loss=1.5702 elapsed=41.0s
|
| 48 |
+
[lta] step=3700 loss=1.5327 elapsed=42.1s
|
| 49 |
+
[lta] step=3800 loss=1.5319 elapsed=43.1s
|
| 50 |
+
[lta] step=3900 loss=1.3098 elapsed=44.3s
|
| 51 |
+
[lta] step=4000 loss=1.6050 elapsed=45.3s
|
| 52 |
+
[lta] step=4100 loss=1.2478 elapsed=46.8s
|
| 53 |
+
[lta] step=4200 loss=1.3497 elapsed=47.8s
|
| 54 |
+
[lta] step=4300 loss=1.3263 elapsed=48.8s
|
| 55 |
+
[lta] step=4400 loss=1.3406 elapsed=49.9s
|
| 56 |
+
[lta] step=4500 loss=1.3295 elapsed=50.9s
|
| 57 |
+
[lta] step=4600 loss=1.5340 elapsed=52.4s
|
| 58 |
+
[lta] step=4700 loss=1.4847 elapsed=53.4s
|
| 59 |
+
[lta] step=4800 loss=1.1464 elapsed=54.4s
|
| 60 |
+
[lta] step=4900 loss=1.4102 elapsed=55.5s
|
| 61 |
+
[lta] step=5000 loss=1.4638 elapsed=56.5s
|
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/sample_fully_coupled.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import csv
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 15 |
+
if str(SCRIPT_DIR) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(SCRIPT_DIR))
|
| 17 |
+
|
| 18 |
+
from train_char import CharTokenizer, ModelConfig, TinyTransformer, standard_gamma, text_stats
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def pick_device(name: str) -> torch.device:
|
| 22 |
+
if name != "auto":
|
| 23 |
+
return torch.device(name)
|
| 24 |
+
if torch.cuda.is_available():
|
| 25 |
+
return torch.device("cuda")
|
| 26 |
+
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
| 27 |
+
return torch.device("mps")
|
| 28 |
+
return torch.device("cpu")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def decode(
|
| 32 |
+
model: TinyTransformer,
|
| 33 |
+
tokenizer: CharTokenizer,
|
| 34 |
+
*,
|
| 35 |
+
length: int,
|
| 36 |
+
steps: int,
|
| 37 |
+
c_min: float,
|
| 38 |
+
c_max: float,
|
| 39 |
+
temp: float,
|
| 40 |
+
final_from: str,
|
| 41 |
+
seed: int,
|
| 42 |
+
device: torch.device,
|
| 43 |
+
) -> str:
|
| 44 |
+
torch.manual_seed(seed)
|
| 45 |
+
eps = 1e-8
|
| 46 |
+
vocab_size = tokenizer.vocab_size
|
| 47 |
+
alpha = torch.full((1, length, vocab_size), 1.0 / vocab_size, device=device).clamp_min(eps)
|
| 48 |
+
probs = standard_gamma(alpha).clamp_min(eps)
|
| 49 |
+
probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 50 |
+
last_endpoint = probs
|
| 51 |
+
for step in range(steps):
|
| 52 |
+
t_value = (step + 1) / max(steps, 1)
|
| 53 |
+
t = torch.full((1,), t_value, device=device)
|
| 54 |
+
logits = model(probs, t) / temp
|
| 55 |
+
endpoint = F.softmax(logits, dim=-1)
|
| 56 |
+
last_endpoint = endpoint
|
| 57 |
+
support_t = t_value
|
| 58 |
+
semantic_t = t_value
|
| 59 |
+
forward_endpoint = (1.0 - semantic_t) * probs + semantic_t * endpoint
|
| 60 |
+
mean = (1.0 - support_t) / float(vocab_size) + support_t * forward_endpoint
|
| 61 |
+
mean = mean.clamp_min(eps)
|
| 62 |
+
mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 63 |
+
conc = math.exp(math.log(c_min) + support_t * math.log(c_max / c_min))
|
| 64 |
+
sample = standard_gamma((mean * conc).clamp_min(eps)).clamp_min(eps)
|
| 65 |
+
probs = sample / sample.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 66 |
+
if final_from == "state":
|
| 67 |
+
final = probs
|
| 68 |
+
elif final_from == "endpoint":
|
| 69 |
+
final = last_endpoint
|
| 70 |
+
elif final_from == "blend":
|
| 71 |
+
final = 0.5 * probs + 0.5 * last_endpoint
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(final_from)
|
| 74 |
+
ids = final.argmax(dim=-1)[0]
|
| 75 |
+
return tokenizer.decode(ids)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main() -> None:
|
| 79 |
+
p = argparse.ArgumentParser()
|
| 80 |
+
p.add_argument("--checkpoint", required=True)
|
| 81 |
+
p.add_argument("--out_dir", required=True)
|
| 82 |
+
p.add_argument("--length", type=int, default=128)
|
| 83 |
+
p.add_argument("--seed", type=int, default=20260507)
|
| 84 |
+
p.add_argument("--device", default="auto")
|
| 85 |
+
args = p.parse_args()
|
| 86 |
+
|
| 87 |
+
device = pick_device(args.device)
|
| 88 |
+
print(f"[setup] device={device}", flush=True)
|
| 89 |
+
ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
| 90 |
+
cfg = ModelConfig(**ckpt["model_config"])
|
| 91 |
+
tok_data = ckpt["extra"]["tokenizer"]
|
| 92 |
+
tokenizer = CharTokenizer("".join(tok_data["itos"]))
|
| 93 |
+
tokenizer.itos = tok_data["itos"]
|
| 94 |
+
tokenizer.stoi = tok_data["stoi"]
|
| 95 |
+
tokenizer.vocab_size = tok_data["vocab_size"]
|
| 96 |
+
model = TinyTransformer(cfg).to(device)
|
| 97 |
+
model.load_state_dict(ckpt["model"])
|
| 98 |
+
model.eval()
|
| 99 |
+
|
| 100 |
+
configs = []
|
| 101 |
+
for steps in [128, 256, 512, 1024]:
|
| 102 |
+
for c_max in [64.0, 16.0, 4.0, 1.0]:
|
| 103 |
+
for temp in [0.8, 1.0, 1.3, 1.8]:
|
| 104 |
+
for final_from in ["state", "endpoint", "blend"]:
|
| 105 |
+
configs.append((steps, c_max, temp, final_from))
|
| 106 |
+
|
| 107 |
+
out_dir = Path(args.out_dir)
|
| 108 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 109 |
+
rows = []
|
| 110 |
+
for i, (steps, c_max, temp, final_from) in enumerate(configs):
|
| 111 |
+
name = f"steps{steps}_c{c_max:g}_temp{str(temp).replace('.', 'p')}_{final_from}"
|
| 112 |
+
text = decode(
|
| 113 |
+
model,
|
| 114 |
+
tokenizer,
|
| 115 |
+
length=args.length,
|
| 116 |
+
steps=steps,
|
| 117 |
+
c_min=1.0,
|
| 118 |
+
c_max=c_max,
|
| 119 |
+
temp=temp,
|
| 120 |
+
final_from=final_from,
|
| 121 |
+
seed=args.seed,
|
| 122 |
+
device=device,
|
| 123 |
+
)
|
| 124 |
+
stats = text_stats(text)
|
| 125 |
+
row = {
|
| 126 |
+
"name": name,
|
| 127 |
+
"steps": steps,
|
| 128 |
+
"c_max": c_max,
|
| 129 |
+
"temp": temp,
|
| 130 |
+
"final_from": final_from,
|
| 131 |
+
**stats,
|
| 132 |
+
}
|
| 133 |
+
rows.append(row)
|
| 134 |
+
(out_dir / f"{name}.txt").write_text(text, encoding="utf-8")
|
| 135 |
+
if i % 12 == 0:
|
| 136 |
+
print("[sample]", row, flush=True)
|
| 137 |
+
|
| 138 |
+
keys = list(rows[0].keys())
|
| 139 |
+
with (out_dir / "summary.tsv").open("w", encoding="utf-8") as f:
|
| 140 |
+
writer = csv.DictWriter(f, fieldnames=keys, delimiter="\t")
|
| 141 |
+
writer.writeheader()
|
| 142 |
+
writer.writerows(rows)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
main()
|
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/train_char.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import urllib.request
|
| 10 |
+
from dataclasses import asdict, dataclass
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def setup_distributed(name: str) -> tuple[torch.device, int, int, int, bool]:
|
| 24 |
+
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
| 25 |
+
rank = int(os.environ.get("RANK", "0"))
|
| 26 |
+
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
| 27 |
+
if world_size > 1:
|
| 28 |
+
if not torch.cuda.is_available():
|
| 29 |
+
raise RuntimeError("DDP mode expects CUDA. Run single-process for CPU/MPS.")
|
| 30 |
+
torch.cuda.set_device(local_rank)
|
| 31 |
+
dist.init_process_group(backend="nccl")
|
| 32 |
+
return torch.device("cuda", local_rank), rank, local_rank, world_size, True
|
| 33 |
+
if name != "auto":
|
| 34 |
+
return torch.device(name), rank, local_rank, world_size, False
|
| 35 |
+
if torch.cuda.is_available():
|
| 36 |
+
return torch.device("cuda"), rank, local_rank, world_size, False
|
| 37 |
+
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
| 38 |
+
return torch.device("mps"), rank, local_rank, world_size, False
|
| 39 |
+
return torch.device("cpu"), rank, local_rank, world_size, False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def cleanup_distributed(is_ddp: bool) -> None:
|
| 43 |
+
if is_ddp and dist.is_initialized():
|
| 44 |
+
dist.destroy_process_group()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def is_main_process(rank: int) -> bool:
|
| 48 |
+
return rank == 0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CharTokenizer:
|
| 52 |
+
def __init__(self, text: str):
|
| 53 |
+
chars = sorted(set(text))
|
| 54 |
+
self.itos = chars
|
| 55 |
+
self.stoi = {ch: i for i, ch in enumerate(chars)}
|
| 56 |
+
self.vocab_size = len(chars)
|
| 57 |
+
|
| 58 |
+
def encode(self, text: str) -> list[int]:
|
| 59 |
+
return [self.stoi[ch] for ch in text]
|
| 60 |
+
|
| 61 |
+
def decode(self, ids: list[int] | torch.Tensor) -> str:
|
| 62 |
+
if isinstance(ids, torch.Tensor):
|
| 63 |
+
ids = ids.detach().cpu().tolist()
|
| 64 |
+
return "".join(self.itos[int(i)] for i in ids)
|
| 65 |
+
|
| 66 |
+
def to_json(self) -> dict:
|
| 67 |
+
return {"itos": self.itos, "stoi": self.stoi, "vocab_size": self.vocab_size}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def ensure_tinyshakespeare(data_dir: Path) -> None:
|
| 71 |
+
data_dir.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
path = data_dir / "input.txt"
|
| 73 |
+
if not path.exists():
|
| 74 |
+
print(f"[data] downloading {DATA_URL}", flush=True)
|
| 75 |
+
urllib.request.urlretrieve(DATA_URL, path)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def load_tinyshakespeare(data_dir: Path) -> tuple[str, CharTokenizer, torch.Tensor, torch.Tensor]:
|
| 79 |
+
path = data_dir / "input.txt"
|
| 80 |
+
text = path.read_text(encoding="utf-8")
|
| 81 |
+
tokenizer = CharTokenizer(text)
|
| 82 |
+
ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
|
| 83 |
+
split = int(0.9 * len(ids))
|
| 84 |
+
return text, tokenizer, ids[:split], ids[split:]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_batch(data: torch.Tensor, *, batch_size: int, block_size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
| 88 |
+
ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
|
| 89 |
+
x = torch.stack([data[i : i + block_size] for i in ix])
|
| 90 |
+
y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
|
| 91 |
+
return x.to(device), y.to(device)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_block_batch(data: torch.Tensor, *, batch_size: int, block_size: int, device: torch.device) -> torch.Tensor:
|
| 95 |
+
ix = torch.randint(0, len(data) - block_size, (batch_size,))
|
| 96 |
+
x = torch.stack([data[i : i + block_size] for i in ix])
|
| 97 |
+
return x.to(device)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _wrapped_window(data: torch.Tensor, start: int, width: int) -> torch.Tensor:
|
| 101 |
+
end = start + width
|
| 102 |
+
if end <= len(data):
|
| 103 |
+
return data[start:end]
|
| 104 |
+
return torch.cat([data[start:], data[: end % len(data)]], dim=0)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_stream_batch(
|
| 108 |
+
data: torch.Tensor,
|
| 109 |
+
*,
|
| 110 |
+
batch_size: int,
|
| 111 |
+
block_size: int,
|
| 112 |
+
device: torch.device,
|
| 113 |
+
cursor: int,
|
| 114 |
+
rank: int = 0,
|
| 115 |
+
world_size: int = 1,
|
| 116 |
+
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
| 117 |
+
width = block_size + 1
|
| 118 |
+
base = (cursor + rank * batch_size * width) % len(data)
|
| 119 |
+
samples = torch.stack([_wrapped_window(data, (base + i * width) % len(data), width) for i in range(batch_size)])
|
| 120 |
+
next_cursor = (cursor + world_size * batch_size * width) % len(data)
|
| 121 |
+
return samples[:, :block_size].to(device), samples[:, 1:].to(device), next_cursor
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_stream_block_batch(
|
| 125 |
+
data: torch.Tensor,
|
| 126 |
+
*,
|
| 127 |
+
batch_size: int,
|
| 128 |
+
block_size: int,
|
| 129 |
+
device: torch.device,
|
| 130 |
+
cursor: int,
|
| 131 |
+
rank: int = 0,
|
| 132 |
+
world_size: int = 1,
|
| 133 |
+
) -> tuple[torch.Tensor, int]:
|
| 134 |
+
width = block_size
|
| 135 |
+
base = (cursor + rank * batch_size * width) % len(data)
|
| 136 |
+
samples = torch.stack([_wrapped_window(data, (base + i * width) % len(data), width) for i in range(batch_size)])
|
| 137 |
+
next_cursor = (cursor + world_size * batch_size * width) % len(data)
|
| 138 |
+
return samples.to(device), next_cursor
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclass
|
| 142 |
+
class ModelConfig:
|
| 143 |
+
vocab_size: int
|
| 144 |
+
block_size: int = 128
|
| 145 |
+
n_layer: int = 4
|
| 146 |
+
n_head: int = 4
|
| 147 |
+
n_embd: int = 128
|
| 148 |
+
dropout: float = 0.1
|
| 149 |
+
causal: bool = True
|
| 150 |
+
input_kind: str = "tokens"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class CausalSelfAttention(nn.Module):
|
| 154 |
+
def __init__(self, cfg: ModelConfig):
|
| 155 |
+
super().__init__()
|
| 156 |
+
assert cfg.n_embd % cfg.n_head == 0
|
| 157 |
+
self.n_head = cfg.n_head
|
| 158 |
+
self.dropout = cfg.dropout
|
| 159 |
+
self.c_attn = nn.Linear(cfg.n_embd, 3 * cfg.n_embd)
|
| 160 |
+
self.c_proj = nn.Linear(cfg.n_embd, cfg.n_embd)
|
| 161 |
+
self.attn_drop = nn.Dropout(cfg.dropout)
|
| 162 |
+
self.resid_drop = nn.Dropout(cfg.dropout)
|
| 163 |
+
self.causal = cfg.causal
|
| 164 |
+
self.register_buffer("tril", torch.tril(torch.ones(cfg.block_size, cfg.block_size)).view(1, 1, cfg.block_size, cfg.block_size))
|
| 165 |
+
|
| 166 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 167 |
+
b, t, c = x.shape
|
| 168 |
+
q, k, v = self.c_attn(x).split(c, dim=2)
|
| 169 |
+
q = q.view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
|
| 170 |
+
k = k.view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
|
| 171 |
+
v = v.view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
|
| 172 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
| 173 |
+
if self.causal:
|
| 174 |
+
att = att.masked_fill(self.tril[:, :, :t, :t] == 0, float("-inf"))
|
| 175 |
+
att = F.softmax(att, dim=-1)
|
| 176 |
+
att = self.attn_drop(att)
|
| 177 |
+
y = att @ v
|
| 178 |
+
y = y.transpose(1, 2).contiguous().view(b, t, c)
|
| 179 |
+
return self.resid_drop(self.c_proj(y))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class Block(nn.Module):
|
| 183 |
+
def __init__(self, cfg: ModelConfig):
|
| 184 |
+
super().__init__()
|
| 185 |
+
self.ln1 = nn.LayerNorm(cfg.n_embd)
|
| 186 |
+
self.attn = CausalSelfAttention(cfg)
|
| 187 |
+
self.ln2 = nn.LayerNorm(cfg.n_embd)
|
| 188 |
+
self.mlp = nn.Sequential(
|
| 189 |
+
nn.Linear(cfg.n_embd, 4 * cfg.n_embd),
|
| 190 |
+
nn.GELU(),
|
| 191 |
+
nn.Linear(4 * cfg.n_embd, cfg.n_embd),
|
| 192 |
+
nn.Dropout(cfg.dropout),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 196 |
+
x = x + self.attn(self.ln1(x))
|
| 197 |
+
x = x + self.mlp(self.ln2(x))
|
| 198 |
+
return x
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class TinyTransformer(nn.Module):
|
| 202 |
+
def __init__(self, cfg: ModelConfig):
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.cfg = cfg
|
| 205 |
+
self.token_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
|
| 206 |
+
self.prob_proj = nn.Linear(cfg.vocab_size, cfg.n_embd, bias=False)
|
| 207 |
+
self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
|
| 208 |
+
self.time_mlp = nn.Sequential(nn.Linear(1, cfg.n_embd), nn.SiLU(), nn.Linear(cfg.n_embd, cfg.n_embd))
|
| 209 |
+
self.drop = nn.Dropout(cfg.dropout)
|
| 210 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
|
| 211 |
+
self.ln_f = nn.LayerNorm(cfg.n_embd)
|
| 212 |
+
self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
|
| 213 |
+
if cfg.input_kind == "tokens":
|
| 214 |
+
self.lm_head.weight = self.token_emb.weight
|
| 215 |
+
self.apply(self._init_weights)
|
| 216 |
+
|
| 217 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 218 |
+
if isinstance(module, nn.Linear):
|
| 219 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 220 |
+
if module.bias is not None:
|
| 221 |
+
nn.init.zeros_(module.bias)
|
| 222 |
+
elif isinstance(module, nn.Embedding):
|
| 223 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 224 |
+
|
| 225 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor | None = None) -> torch.Tensor:
|
| 226 |
+
b, seq_len = x.shape[:2]
|
| 227 |
+
pos = torch.arange(seq_len, device=x.device)
|
| 228 |
+
if self.cfg.input_kind == "tokens":
|
| 229 |
+
h = self.token_emb(x)
|
| 230 |
+
elif self.cfg.input_kind == "probs":
|
| 231 |
+
h = self.prob_proj(x.float())
|
| 232 |
+
if t is None:
|
| 233 |
+
raise ValueError("LTA/prob model requires time t")
|
| 234 |
+
h = h + self.time_mlp(t.float().view(b, 1)).view(b, 1, -1)
|
| 235 |
+
else:
|
| 236 |
+
raise ValueError(f"unknown input_kind: {self.cfg.input_kind}")
|
| 237 |
+
h = self.drop(h + self.pos_emb(pos).view(1, seq_len, -1))
|
| 238 |
+
for block in self.blocks:
|
| 239 |
+
h = block(h)
|
| 240 |
+
return self.lm_head(self.ln_f(h))
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@dataclass
|
| 244 |
+
class LTAConfig:
|
| 245 |
+
c_min: float = 1.0
|
| 246 |
+
c_max: float = 64.0
|
| 247 |
+
endpoint_mode: str = "full_vocab_wrong"
|
| 248 |
+
t_mode: str = "same"
|
| 249 |
+
eps: float = 1e-8
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def concentration(t: torch.Tensor, c_min: float, c_max: float) -> torch.Tensor:
|
| 253 |
+
return torch.exp(torch.log(torch.tensor(c_min, device=t.device)) + t * math.log(c_max / c_min))
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def standard_gamma(alpha: torch.Tensor) -> torch.Tensor:
|
| 257 |
+
# MPS does not implement aten::_standard_gamma yet. Sampling on CPU is
|
| 258 |
+
# plenty fast for this tiny char-level experiment, while the Transformer
|
| 259 |
+
# still runs on the accelerator.
|
| 260 |
+
if alpha.device.type == "mps":
|
| 261 |
+
return torch._standard_gamma(alpha.cpu()).to(alpha.device)
|
| 262 |
+
return torch._standard_gamma(alpha)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def corrupt_categorical_simplex(ids: torch.Tensor, vocab_size: int, cfg: LTAConfig) -> tuple[torch.Tensor, torch.Tensor]:
|
| 266 |
+
b, seq_len = ids.shape
|
| 267 |
+
device = ids.device
|
| 268 |
+
support_t = torch.rand(b, device=device)
|
| 269 |
+
if cfg.t_mode == "same":
|
| 270 |
+
semantic_t = support_t
|
| 271 |
+
elif cfg.t_mode == "independent":
|
| 272 |
+
semantic_t = torch.rand(b, device=device)
|
| 273 |
+
else:
|
| 274 |
+
raise ValueError(f"unknown t_mode: {cfg.t_mode}")
|
| 275 |
+
|
| 276 |
+
gold = F.one_hot(ids, vocab_size).float()
|
| 277 |
+
wrong_ids = torch.randint(0, vocab_size, ids.shape, device=device)
|
| 278 |
+
wrong = F.one_hot(wrong_ids, vocab_size).float()
|
| 279 |
+
endpoint = semantic_t.view(b, 1, 1) * gold + (1.0 - semantic_t).view(b, 1, 1) * wrong
|
| 280 |
+
|
| 281 |
+
support = support_t.view(b, 1, 1)
|
| 282 |
+
mean = (1.0 - support) / float(vocab_size) + support * endpoint
|
| 283 |
+
mean = mean.clamp_min(cfg.eps)
|
| 284 |
+
mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(cfg.eps)
|
| 285 |
+
conc = concentration(support_t, cfg.c_min, cfg.c_max).view(b, 1, 1)
|
| 286 |
+
alpha = (mean * conc).clamp_min(cfg.eps)
|
| 287 |
+
state = standard_gamma(alpha).clamp_min(cfg.eps)
|
| 288 |
+
state = state / state.sum(dim=-1, keepdim=True).clamp_min(cfg.eps)
|
| 289 |
+
return state, support_t
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
@torch.no_grad()
|
| 293 |
+
def estimate_ar_loss(model: TinyTransformer, data: torch.Tensor, args, device: torch.device, eval_iters: int) -> float:
|
| 294 |
+
model.eval()
|
| 295 |
+
losses = []
|
| 296 |
+
for _ in range(eval_iters):
|
| 297 |
+
x, y = get_batch(data, batch_size=args.batch_size, block_size=args.block_size, device=device)
|
| 298 |
+
logits = model(x)
|
| 299 |
+
losses.append(F.cross_entropy(logits.view(-1, logits.size(-1)), y.reshape(-1)).item())
|
| 300 |
+
model.train()
|
| 301 |
+
return float(sum(losses) / len(losses))
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@torch.no_grad()
|
| 305 |
+
def estimate_lta_loss(model: TinyTransformer, data: torch.Tensor, args, lta_cfg: LTAConfig, device: torch.device, eval_iters: int) -> float:
|
| 306 |
+
model.eval()
|
| 307 |
+
losses = []
|
| 308 |
+
for _ in range(eval_iters):
|
| 309 |
+
ids = get_block_batch(data, batch_size=args.batch_size, block_size=args.block_size, device=device)
|
| 310 |
+
state, t = corrupt_categorical_simplex(ids, model.cfg.vocab_size, lta_cfg)
|
| 311 |
+
logits = model(state, t)
|
| 312 |
+
losses.append(F.cross_entropy(logits.view(-1, logits.size(-1)), ids.reshape(-1)).item())
|
| 313 |
+
model.train()
|
| 314 |
+
return float(sum(losses) / len(losses))
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@torch.no_grad()
|
| 318 |
+
def generate_ar(model: TinyTransformer, tokenizer: CharTokenizer, *, length: int, temp: float, device: torch.device, seed: str = "\n") -> str:
|
| 319 |
+
model.eval()
|
| 320 |
+
ids = torch.tensor([tokenizer.encode(seed)], dtype=torch.long, device=device)
|
| 321 |
+
for _ in range(length):
|
| 322 |
+
idx = ids[:, -model.cfg.block_size :]
|
| 323 |
+
logits = model(idx)[:, -1, :]
|
| 324 |
+
if temp <= 0:
|
| 325 |
+
nxt = logits.argmax(dim=-1, keepdim=True)
|
| 326 |
+
else:
|
| 327 |
+
probs = F.softmax(logits / temp, dim=-1)
|
| 328 |
+
nxt = torch.multinomial(probs, num_samples=1)
|
| 329 |
+
ids = torch.cat([ids, nxt], dim=1)
|
| 330 |
+
return tokenizer.decode(ids[0])
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@torch.no_grad()
|
| 334 |
+
def generate_lta(
|
| 335 |
+
model: TinyTransformer,
|
| 336 |
+
tokenizer: CharTokenizer,
|
| 337 |
+
*,
|
| 338 |
+
length: int,
|
| 339 |
+
steps: int,
|
| 340 |
+
c_min: float,
|
| 341 |
+
c_max: float,
|
| 342 |
+
temp: float,
|
| 343 |
+
device: torch.device,
|
| 344 |
+
) -> str:
|
| 345 |
+
model.eval()
|
| 346 |
+
eps = 1e-8
|
| 347 |
+
vocab_size = tokenizer.vocab_size
|
| 348 |
+
alpha = torch.full((1, length, vocab_size), 1.0 / vocab_size, device=device).clamp_min(eps)
|
| 349 |
+
probs = standard_gamma(alpha).clamp_min(eps)
|
| 350 |
+
probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 351 |
+
last_endpoint = probs
|
| 352 |
+
for step in range(steps):
|
| 353 |
+
t_value = (step + 1) / max(steps, 1)
|
| 354 |
+
t = torch.full((1,), t_value, device=device)
|
| 355 |
+
logits = model(probs, t) / temp
|
| 356 |
+
endpoint = F.softmax(logits, dim=-1)
|
| 357 |
+
last_endpoint = endpoint
|
| 358 |
+
support_t = t_value
|
| 359 |
+
semantic_t = t_value
|
| 360 |
+
forward_endpoint = (1.0 - semantic_t) * probs + semantic_t * endpoint
|
| 361 |
+
mean = (1.0 - support_t) / float(vocab_size) + support_t * forward_endpoint
|
| 362 |
+
mean = mean.clamp_min(eps)
|
| 363 |
+
mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 364 |
+
conc = math.exp(math.log(c_min) + support_t * math.log(c_max / c_min))
|
| 365 |
+
sample = standard_gamma((mean * conc).clamp_min(eps)).clamp_min(eps)
|
| 366 |
+
probs = sample / sample.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 367 |
+
final = 0.5 * probs + 0.5 * last_endpoint
|
| 368 |
+
ids = final.argmax(dim=-1)[0]
|
| 369 |
+
return tokenizer.decode(ids)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def text_stats(text: str) -> dict[str, float]:
|
| 373 |
+
chars = list(text)
|
| 374 |
+
counts = {}
|
| 375 |
+
for ch in chars:
|
| 376 |
+
counts[ch] = counts.get(ch, 0) + 1
|
| 377 |
+
n = max(len(chars), 1)
|
| 378 |
+
entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
|
| 379 |
+
bigrams = list(zip(chars, chars[1:]))
|
| 380 |
+
distinct_2 = len(set(bigrams)) / max(len(bigrams), 1)
|
| 381 |
+
return {"char_entropy": entropy, "distinct_2": distinct_2, "length": float(len(text))}
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def save_checkpoint(path: Path, model: TinyTransformer, optimizer: torch.optim.Optimizer, step: int, cfg: ModelConfig, extra: dict) -> None:
|
| 385 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 386 |
+
torch.save(
|
| 387 |
+
{
|
| 388 |
+
"model": model.state_dict(),
|
| 389 |
+
"optimizer": optimizer.state_dict(),
|
| 390 |
+
"step": step,
|
| 391 |
+
"model_config": asdict(cfg),
|
| 392 |
+
"extra": extra,
|
| 393 |
+
},
|
| 394 |
+
path,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def maybe_resume(path: str | None, model: TinyTransformer, optimizer: torch.optim.Optimizer, device: torch.device) -> int:
|
| 399 |
+
if not path:
|
| 400 |
+
return 0
|
| 401 |
+
ckpt = torch.load(path, map_location=device)
|
| 402 |
+
model.load_state_dict(ckpt["model"])
|
| 403 |
+
if "optimizer" in ckpt:
|
| 404 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 405 |
+
return int(ckpt.get("step", 0))
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def train_ar(
|
| 409 |
+
args,
|
| 410 |
+
tokenizer: CharTokenizer,
|
| 411 |
+
train_data: torch.Tensor,
|
| 412 |
+
val_data: torch.Tensor,
|
| 413 |
+
device: torch.device,
|
| 414 |
+
*,
|
| 415 |
+
rank: int,
|
| 416 |
+
local_rank: int,
|
| 417 |
+
is_ddp: bool,
|
| 418 |
+
) -> None:
|
| 419 |
+
cfg = ModelConfig(
|
| 420 |
+
vocab_size=tokenizer.vocab_size,
|
| 421 |
+
block_size=args.block_size,
|
| 422 |
+
n_layer=args.n_layer,
|
| 423 |
+
n_head=args.n_head,
|
| 424 |
+
n_embd=args.n_embd,
|
| 425 |
+
dropout=args.dropout,
|
| 426 |
+
causal=True,
|
| 427 |
+
input_kind="tokens",
|
| 428 |
+
)
|
| 429 |
+
raw_model = TinyTransformer(cfg).to(device)
|
| 430 |
+
model = DDP(raw_model, device_ids=[local_rank], find_unused_parameters=True) if is_ddp else raw_model
|
| 431 |
+
optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
| 432 |
+
start_step = maybe_resume(args.resume_path, raw_model, optimizer, device)
|
| 433 |
+
out_dir = Path(args.out_dir) / "ar"
|
| 434 |
+
if is_main_process(rank):
|
| 435 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 436 |
+
log_path = out_dir / "metrics.jsonl"
|
| 437 |
+
t0 = time.time()
|
| 438 |
+
stream_world_size = dist.get_world_size() if is_ddp and dist.is_initialized() else 1
|
| 439 |
+
stream_cursor = (start_step * args.batch_size * (args.block_size + 1) * stream_world_size) % len(train_data)
|
| 440 |
+
for step in range(start_step + 1, args.steps + 1):
|
| 441 |
+
if args.data_mode == "stream":
|
| 442 |
+
x, y, stream_cursor = get_stream_batch(
|
| 443 |
+
train_data,
|
| 444 |
+
batch_size=args.batch_size,
|
| 445 |
+
block_size=args.block_size,
|
| 446 |
+
device=device,
|
| 447 |
+
cursor=stream_cursor,
|
| 448 |
+
rank=rank,
|
| 449 |
+
world_size=stream_world_size,
|
| 450 |
+
)
|
| 451 |
+
else:
|
| 452 |
+
x, y = get_batch(train_data, batch_size=args.batch_size, block_size=args.block_size, device=device)
|
| 453 |
+
logits = model(x)
|
| 454 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.reshape(-1))
|
| 455 |
+
optimizer.zero_grad(set_to_none=True)
|
| 456 |
+
loss.backward()
|
| 457 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
| 458 |
+
optimizer.step()
|
| 459 |
+
if is_main_process(rank) and (step % args.log_interval == 0 or step == 1):
|
| 460 |
+
print(f"[ar] step={step} loss={loss.item():.4f} elapsed={time.time() - t0:.1f}s", flush=True)
|
| 461 |
+
if step % args.eval_interval == 0 or step == args.steps:
|
| 462 |
+
if is_ddp:
|
| 463 |
+
dist.barrier()
|
| 464 |
+
if is_main_process(rank):
|
| 465 |
+
val_loss = estimate_ar_loss(raw_model, val_data, args, device, args.eval_iters)
|
| 466 |
+
sample = generate_ar(raw_model, tokenizer, length=args.sample_len, temp=args.ar_temp, device=device)
|
| 467 |
+
row = {"step": step, "train_loss": float(loss.item()), "val_loss": val_loss, **text_stats(sample)}
|
| 468 |
+
with log_path.open("a", encoding="utf-8") as f:
|
| 469 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 470 |
+
(out_dir / f"sample_step{step:05d}.txt").write_text(sample, encoding="utf-8")
|
| 471 |
+
save_checkpoint(out_dir / "latest.pt", raw_model, optimizer, step, cfg, {"mode": "ar", "data_mode": args.data_mode, "tokenizer": tokenizer.to_json()})
|
| 472 |
+
if is_ddp:
|
| 473 |
+
dist.barrier()
|
| 474 |
+
if is_main_process(rank):
|
| 475 |
+
save_checkpoint(out_dir / f"step_{args.steps:05d}.pt", raw_model, optimizer, args.steps, cfg, {"mode": "ar", "data_mode": args.data_mode, "tokenizer": tokenizer.to_json()})
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def train_lta(
|
| 479 |
+
args,
|
| 480 |
+
tokenizer: CharTokenizer,
|
| 481 |
+
train_data: torch.Tensor,
|
| 482 |
+
val_data: torch.Tensor,
|
| 483 |
+
device: torch.device,
|
| 484 |
+
*,
|
| 485 |
+
rank: int,
|
| 486 |
+
local_rank: int,
|
| 487 |
+
is_ddp: bool,
|
| 488 |
+
) -> None:
|
| 489 |
+
cfg = ModelConfig(
|
| 490 |
+
vocab_size=tokenizer.vocab_size,
|
| 491 |
+
block_size=args.block_size,
|
| 492 |
+
n_layer=args.n_layer,
|
| 493 |
+
n_head=args.n_head,
|
| 494 |
+
n_embd=args.n_embd,
|
| 495 |
+
dropout=args.dropout,
|
| 496 |
+
causal=False,
|
| 497 |
+
input_kind="probs",
|
| 498 |
+
)
|
| 499 |
+
lta_cfg = LTAConfig(c_min=args.c_min, c_max=args.c_max, t_mode=args.t_mode)
|
| 500 |
+
raw_model = TinyTransformer(cfg).to(device)
|
| 501 |
+
model = DDP(raw_model, device_ids=[local_rank], find_unused_parameters=True) if is_ddp else raw_model
|
| 502 |
+
optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
| 503 |
+
start_step = maybe_resume(args.resume_path, raw_model, optimizer, device)
|
| 504 |
+
out_dir = Path(args.out_dir) / ("fully_coupled" if args.mode == "fully_coupled" else "lta")
|
| 505 |
+
if is_main_process(rank):
|
| 506 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 507 |
+
log_path = out_dir / "metrics.jsonl"
|
| 508 |
+
t0 = time.time()
|
| 509 |
+
stream_world_size = dist.get_world_size() if is_ddp and dist.is_initialized() else 1
|
| 510 |
+
stream_cursor = (start_step * args.batch_size * args.block_size * stream_world_size) % len(train_data)
|
| 511 |
+
for step in range(start_step + 1, args.steps + 1):
|
| 512 |
+
if args.data_mode == "stream":
|
| 513 |
+
ids, stream_cursor = get_stream_block_batch(
|
| 514 |
+
train_data,
|
| 515 |
+
batch_size=args.batch_size,
|
| 516 |
+
block_size=args.block_size,
|
| 517 |
+
device=device,
|
| 518 |
+
cursor=stream_cursor,
|
| 519 |
+
rank=rank,
|
| 520 |
+
world_size=stream_world_size,
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
ids = get_block_batch(train_data, batch_size=args.batch_size, block_size=args.block_size, device=device)
|
| 524 |
+
state, t = corrupt_categorical_simplex(ids, tokenizer.vocab_size, lta_cfg)
|
| 525 |
+
logits = model(state, t)
|
| 526 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), ids.reshape(-1))
|
| 527 |
+
optimizer.zero_grad(set_to_none=True)
|
| 528 |
+
loss.backward()
|
| 529 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
| 530 |
+
optimizer.step()
|
| 531 |
+
if is_main_process(rank) and (step % args.log_interval == 0 or step == 1):
|
| 532 |
+
print(f"[lta] step={step} loss={loss.item():.4f} elapsed={time.time() - t0:.1f}s", flush=True)
|
| 533 |
+
if step % args.eval_interval == 0 or step == args.steps:
|
| 534 |
+
if is_ddp:
|
| 535 |
+
dist.barrier()
|
| 536 |
+
if is_main_process(rank):
|
| 537 |
+
val_loss = estimate_lta_loss(raw_model, val_data, args, lta_cfg, device, args.eval_iters)
|
| 538 |
+
sample = generate_lta(
|
| 539 |
+
raw_model,
|
| 540 |
+
tokenizer,
|
| 541 |
+
length=min(args.sample_len, args.block_size),
|
| 542 |
+
steps=args.decode_steps,
|
| 543 |
+
c_min=args.c_min,
|
| 544 |
+
c_max=args.decode_c_max,
|
| 545 |
+
temp=args.endpoint_temp,
|
| 546 |
+
device=device,
|
| 547 |
+
)
|
| 548 |
+
row = {"step": step, "train_loss": float(loss.item()), "val_loss": val_loss, **text_stats(sample)}
|
| 549 |
+
with log_path.open("a", encoding="utf-8") as f:
|
| 550 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 551 |
+
(out_dir / f"sample_step{step:05d}.txt").write_text(sample, encoding="utf-8")
|
| 552 |
+
save_checkpoint(out_dir / "latest.pt", raw_model, optimizer, step, cfg, {"mode": "lta", "data_mode": args.data_mode, "lta_config": asdict(lta_cfg), "tokenizer": tokenizer.to_json()})
|
| 553 |
+
if is_ddp:
|
| 554 |
+
dist.barrier()
|
| 555 |
+
if is_main_process(rank):
|
| 556 |
+
save_checkpoint(out_dir / f"step_{args.steps:05d}.pt", raw_model, optimizer, args.steps, cfg, {"mode": "lta", "data_mode": args.data_mode, "lta_config": asdict(lta_cfg), "tokenizer": tokenizer.to_json()})
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def main() -> None:
|
| 560 |
+
p = argparse.ArgumentParser()
|
| 561 |
+
p.add_argument("--mode", choices=["ar", "fully_coupled", "lta", "both"], default="both")
|
| 562 |
+
p.add_argument("--data_dir", default="experiments/nanogpt_tinyshakespeare_char/data")
|
| 563 |
+
p.add_argument("--out_dir", default="experiments/nanogpt_tinyshakespeare_char/runs/char_5k")
|
| 564 |
+
p.add_argument("--device", default="auto")
|
| 565 |
+
p.add_argument("--steps", type=int, default=5000)
|
| 566 |
+
p.add_argument("--data_mode", choices=["random", "stream"], default="random")
|
| 567 |
+
p.add_argument("--batch_size", type=int, default=64)
|
| 568 |
+
p.add_argument("--block_size", type=int, default=128)
|
| 569 |
+
p.add_argument("--n_layer", type=int, default=4)
|
| 570 |
+
p.add_argument("--n_head", type=int, default=4)
|
| 571 |
+
p.add_argument("--n_embd", type=int, default=128)
|
| 572 |
+
p.add_argument("--dropout", type=float, default=0.1)
|
| 573 |
+
p.add_argument("--lr", type=float, default=3e-4)
|
| 574 |
+
p.add_argument("--weight_decay", type=float, default=0.1)
|
| 575 |
+
p.add_argument("--grad_clip", type=float, default=1.0)
|
| 576 |
+
p.add_argument("--log_interval", type=int, default=100)
|
| 577 |
+
p.add_argument("--eval_interval", type=int, default=500)
|
| 578 |
+
p.add_argument("--eval_iters", type=int, default=20)
|
| 579 |
+
p.add_argument("--sample_len", type=int, default=512)
|
| 580 |
+
p.add_argument("--ar_temp", type=float, default=0.8)
|
| 581 |
+
p.add_argument("--c_min", type=float, default=1.0)
|
| 582 |
+
p.add_argument("--c_max", type=float, default=64.0)
|
| 583 |
+
p.add_argument("--decode_c_max", type=float, default=16.0)
|
| 584 |
+
p.add_argument("--endpoint_temp", type=float, default=1.3)
|
| 585 |
+
p.add_argument("--decode_steps", type=int, default=256)
|
| 586 |
+
p.add_argument("--t_mode", choices=["same", "independent"], default="same")
|
| 587 |
+
p.add_argument("--resume_path", default="")
|
| 588 |
+
p.add_argument("--seed", type=int, default=1337)
|
| 589 |
+
args = p.parse_args()
|
| 590 |
+
|
| 591 |
+
device, rank, local_rank, world_size, is_ddp = setup_distributed(args.device)
|
| 592 |
+
torch.manual_seed(args.seed + rank)
|
| 593 |
+
if is_main_process(rank):
|
| 594 |
+
print(f"[setup] device={device} rank={rank} world_size={world_size}", flush=True)
|
| 595 |
+
ensure_tinyshakespeare(Path(args.data_dir))
|
| 596 |
+
if is_ddp:
|
| 597 |
+
dist.barrier()
|
| 598 |
+
text, tokenizer, train_data, val_data = load_tinyshakespeare(Path(args.data_dir))
|
| 599 |
+
out_dir = Path(args.out_dir)
|
| 600 |
+
if is_main_process(rank):
|
| 601 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 602 |
+
(out_dir / "tokenizer.json").write_text(json.dumps(tokenizer.to_json(), ensure_ascii=False, indent=2), encoding="utf-8")
|
| 603 |
+
(out_dir / "args.json").write_text(json.dumps(vars(args), ensure_ascii=False, indent=2), encoding="utf-8")
|
| 604 |
+
print(f"[data] chars={len(text)} vocab={tokenizer.vocab_size} train={len(train_data)} val={len(val_data)}", flush=True)
|
| 605 |
+
if is_ddp:
|
| 606 |
+
dist.barrier()
|
| 607 |
+
|
| 608 |
+
try:
|
| 609 |
+
if args.mode in {"ar", "both"}:
|
| 610 |
+
train_ar(args, tokenizer, train_data, val_data, device, rank=rank, local_rank=local_rank, is_ddp=is_ddp)
|
| 611 |
+
if args.mode in {"fully_coupled", "lta", "both"}:
|
| 612 |
+
train_lta(args, tokenizer, train_data, val_data, device, rank=rank, local_rank=local_rank, is_ddp=is_ddp)
|
| 613 |
+
finally:
|
| 614 |
+
cleanup_distributed(is_ddp)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
if __name__ == "__main__":
|
| 618 |
+
main()
|
LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513.log
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
*****************************************
|
| 3 |
+
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
*****************************************
|
| 5 |
+
[rank0]:[W513 02:20:46.100148121 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 6 |
+
NCCL version 2.25.1+cuda12.8
|
| 7 |
+
[rank3]:[W513 02:20:46.141519466 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 8 |
+
[rank1]:[W513 02:20:46.143177080 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 9 |
+
[rank2]:[W513 02:20:46.172962616 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
|
| 10 |
+
{
|
| 11 |
+
"device": "cuda:0",
|
| 12 |
+
"rank": 0,
|
| 13 |
+
"world_size": 4,
|
| 14 |
+
"samples": "owt_cached_chunks:10904",
|
| 15 |
+
"vocab_size": 50257,
|
| 16 |
+
"tokenizer_vocab_size": 50257,
|
| 17 |
+
"save_dir": "runs/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513",
|
| 18 |
+
"batch_size": 8,
|
| 19 |
+
"grad_accum": 16,
|
| 20 |
+
"effective_batch_size": 512,
|
| 21 |
+
"global_batch_size": 512,
|
| 22 |
+
"lr_schedule": "constant_warmup",
|
| 23 |
+
"optimizer": "muon",
|
| 24 |
+
"warmup_steps": 11,
|
| 25 |
+
"min_lr": 0.0,
|
| 26 |
+
"weight_decay": 0.0,
|
| 27 |
+
"adamw_param_groups": "nanogpt",
|
| 28 |
+
"adam_beta1": 0.9,
|
| 29 |
+
"adam_beta2": 0.95,
|
| 30 |
+
"adam_eps": 1e-08,
|
| 31 |
+
"muon_momentum": 0.95,
|
| 32 |
+
"muon_ns_steps": 5,
|
| 33 |
+
"muon_update_scale": 1.0,
|
| 34 |
+
"ema_decay": 0.9999,
|
| 35 |
+
"ema_start_step": 0,
|
| 36 |
+
"model_type": "ddit",
|
| 37 |
+
"dual_t": true,
|
| 38 |
+
"corrupt_t_mode": "independent",
|
| 39 |
+
"corrupt_min_t": null,
|
| 40 |
+
"corrupt_max_t": null,
|
| 41 |
+
"prefix_block_prob": 0.0,
|
| 42 |
+
"prefix_block_len": 128,
|
| 43 |
+
"dirichlet_endpoint_mode": "categorical_dual_t",
|
| 44 |
+
"dirichlet_semantic_t_mode": "same",
|
| 45 |
+
"dirichlet_semantic_t_value": 0.0,
|
| 46 |
+
"categorical_wrong_from_full_vocab": true,
|
| 47 |
+
"categorical_wrong_from_batch_valid_tokens": false,
|
| 48 |
+
"mask_mixture_original_prob": 0.0,
|
| 49 |
+
"mask_mixture_lowk_prob": 0.0,
|
| 50 |
+
"mask_mixture_lowcorrupt_prob": 0.0,
|
| 51 |
+
"mask_mixture_block_prob": 0.0,
|
| 52 |
+
"mask_mixture_all_prob": 0.0,
|
| 53 |
+
"mask_mixture_lowk_clean_tokens": "1,2,4,8,16,32,64",
|
| 54 |
+
"mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
|
| 55 |
+
"mask_mixture_block_tokens": "64,128",
|
| 56 |
+
"simplex_bridge_sampler": "dirichlet",
|
| 57 |
+
"logistic_normal_sigma_min": 0.18,
|
| 58 |
+
"logistic_normal_sigma_max": 2.2,
|
| 59 |
+
"logistic_normal_tau_min": 0.65,
|
| 60 |
+
"logistic_normal_tau_max": 1.15,
|
| 61 |
+
"torch_compile": false,
|
| 62 |
+
"compile_mode": "max-autotune",
|
| 63 |
+
"state_format": "prob",
|
| 64 |
+
"target_loss": "hard_ce",
|
| 65 |
+
"meanflow_weight": 0.0,
|
| 66 |
+
"bridge_noise_init": "logistic_normal",
|
| 67 |
+
"noise_sigma": -1.0,
|
| 68 |
+
"wrap": true,
|
| 69 |
+
"wrap_mode": "stream",
|
| 70 |
+
"wrap_record_buffer_size": 200,
|
| 71 |
+
"owt_cached_chunks": true,
|
| 72 |
+
"owt_chunk_cache_dir": "/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k_fast10k",
|
| 73 |
+
"owt_chunk_cache_rebuild": false,
|
| 74 |
+
"owt_chunk_cache_write_batch": 4096,
|
| 75 |
+
"owt_exact_repeat_per_chunk": 0,
|
| 76 |
+
"online_chunk_shuffle": false,
|
| 77 |
+
"online_chunk_shuffle_buffer": 10000,
|
| 78 |
+
"openwebtext_split": "all",
|
| 79 |
+
"detokenizer": "auto",
|
| 80 |
+
"resolved_detokenizer": null,
|
| 81 |
+
"num_workers": 0,
|
| 82 |
+
"latest_every": 25,
|
| 83 |
+
"resume_path": ""
|
| 84 |
+
}
|
| 85 |
+
step=10 micro_steps=160 elapsed=35.3s lr=2.000000e-03 loss_all=10.7819 acc_all=0.5062 loss_corrupt=10.7918 acc_corrupt=0.3523 corrupt_frac=0.5581 loss=10.7918 loss_recon=10.7918 loss_meanflow=0.0000 mean_model_t=0.5021 mean_corrupt_t=0.5155 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4797 init_acc_corrupt=0.4881 init_gold_top10=0.5148 init_gold_top100=0.5444
|
| 86 |
+
step=20 micro_steps=320 elapsed=43.2s lr=2.000000e-03 loss_all=10.5714 acc_all=0.5577 loss_corrupt=10.6476 acc_corrupt=0.3828 corrupt_frac=0.5535 loss=10.6476 loss_recon=10.6476 loss_meanflow=0.0000 mean_model_t=0.4831 mean_corrupt_t=0.5005 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4958 init_acc_corrupt=0.4692 init_gold_top10=0.4982 init_gold_top100=0.5296
|
| 87 |
+
step=30 micro_steps=480 elapsed=49.2s lr=2.000000e-03 loss_all=10.2797 acc_all=0.5418 loss_corrupt=10.4402 acc_corrupt=0.3646 corrupt_frac=0.5515 loss=10.4402 loss_recon=10.4402 loss_meanflow=0.0000 mean_model_t=0.5005 mean_corrupt_t=0.4971 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.5053 init_acc_corrupt=0.4602 init_gold_top10=0.4890 init_gold_top100=0.5193
|
| 88 |
+
step=40 micro_steps=640 elapsed=47.3s lr=2.000000e-03 loss_all=9.9466 acc_all=0.5278 loss_corrupt=10.1895 acc_corrupt=0.3629 corrupt_frac=0.5560 loss=10.1895 loss_recon=10.1895 loss_meanflow=0.0000 mean_model_t=0.4889 mean_corrupt_t=0.5037 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4913 init_acc_corrupt=0.4763 init_gold_top10=0.5031 init_gold_top100=0.5324
|
| 89 |
+
step=50 micro_steps=800 elapsed=48.1s lr=2.000000e-03 loss_all=9.5849 acc_all=0.5083 loss_corrupt=9.9274 acc_corrupt=0.3452 corrupt_frac=0.5483 loss=9.9274 loss_recon=9.9274 loss_meanflow=0.0000 mean_model_t=0.5005 mean_corrupt_t=0.5071 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4913 init_acc_corrupt=0.4737 init_gold_top10=0.5031 init_gold_top100=0.5327
|
| 90 |
+
step=60 micro_steps=960 elapsed=52.1s lr=2.000000e-03 loss_all=9.2104 acc_all=0.4910 loss_corrupt=9.6379 acc_corrupt=0.3375 corrupt_frac=0.5677 loss=9.6379 loss_recon=9.6379 loss_meanflow=0.0000 mean_model_t=0.5078 mean_corrupt_t=0.4997 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4952 init_acc_corrupt=0.4702 init_gold_top10=0.4995 init_gold_top100=0.5282
|
| 91 |
+
step=70 micro_steps=1120 elapsed=49.0s lr=2.000000e-03 loss_all=8.7828 acc_all=0.4820 loss_corrupt=9.3226 acc_corrupt=0.3293 corrupt_frac=0.5563 loss=9.3226 loss_recon=9.3226 loss_meanflow=0.0000 mean_model_t=0.5141 mean_corrupt_t=0.5078 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4938 init_acc_corrupt=0.4720 init_gold_top10=0.5007 init_gold_top100=0.5290
|
| 92 |
+
step=80 micro_steps=1280 elapsed=51.9s lr=2.000000e-03 loss_all=8.3273 acc_all=0.4771 loss_corrupt=8.9524 acc_corrupt=0.3331 corrupt_frac=0.5579 loss=8.9524 loss_recon=8.9524 loss_meanflow=0.0000 mean_model_t=0.5173 mean_corrupt_t=0.5132 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4842 init_acc_corrupt=0.4812 init_gold_top10=0.5102 init_gold_top100=0.5391
|
| 93 |
+
step=90 micro_steps=1440 elapsed=49.3s lr=2.000000e-03 loss_all=7.8580 acc_all=0.4804 loss_corrupt=8.5915 acc_corrupt=0.3368 corrupt_frac=0.5610 loss=8.5915 loss_recon=8.5915 loss_meanflow=0.0000 mean_model_t=0.5097 mean_corrupt_t=0.5138 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4881 init_acc_corrupt=0.4782 init_gold_top10=0.5062 init_gold_top100=0.5356
|
| 94 |
+
step=100 micro_steps=1600 elapsed=49.2s lr=2.000000e-03 loss_all=7.3653 acc_all=0.4879 loss_corrupt=8.2388 acc_corrupt=0.3383 corrupt_frac=0.5443 loss=8.2388 loss_recon=8.2388 loss_meanflow=0.0000 mean_model_t=0.4959 mean_corrupt_t=0.5087 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4917 init_acc_corrupt=0.4737 init_gold_top10=0.5032 init_gold_top100=0.5311
|
LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513_trace.log
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"out_json": "docs/lta_samples/metrics_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513/trace_latest_ema_steps64_c48_t1p45.json", "records": 10, "step": 107}
|
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0010000_t1p45.log
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[watch-classic] 2026-05-21_01:37:41 infer runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000.pt -> docs/lta_samples/metrics_20260520/lm1b_classic_repro_every10k_normal_steps_state_t1p45_c1024_n1024/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000
|
| 2 |
+
[ckpt] runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000.pt step=10000
|
| 3 |
+
[decode] steps128_c1024_t1p45 generated 16/1024
|
| 4 |
+
[decode] steps128_c1024_t1p45 generated 32/1024
|
| 5 |
+
[decode] steps128_c1024_t1p45 generated 48/1024
|
| 6 |
+
[decode] steps128_c1024_t1p45 generated 64/1024
|
| 7 |
+
[decode] steps128_c1024_t1p45 generated 80/1024
|
| 8 |
+
[decode] steps128_c1024_t1p45 generated 96/1024
|
| 9 |
+
[decode] steps128_c1024_t1p45 generated 112/1024
|
| 10 |
+
[decode] steps128_c1024_t1p45 generated 128/1024
|
| 11 |
+
[decode] steps128_c1024_t1p45 generated 144/1024
|
| 12 |
+
[decode] steps128_c1024_t1p45 generated 160/1024
|
| 13 |
+
[decode] steps128_c1024_t1p45 generated 176/1024
|
| 14 |
+
[decode] steps128_c1024_t1p45 generated 192/1024
|
| 15 |
+
[decode] steps128_c1024_t1p45 generated 208/1024
|
| 16 |
+
[decode] steps128_c1024_t1p45 generated 224/1024
|
| 17 |
+
[decode] steps128_c1024_t1p45 generated 240/1024
|
| 18 |
+
[decode] steps128_c1024_t1p45 generated 256/1024
|
| 19 |
+
[decode] steps128_c1024_t1p45 generated 272/1024
|
| 20 |
+
[decode] steps128_c1024_t1p45 generated 288/1024
|
| 21 |
+
[decode] steps128_c1024_t1p45 generated 304/1024
|
| 22 |
+
[decode] steps128_c1024_t1p45 generated 320/1024
|
| 23 |
+
[decode] steps128_c1024_t1p45 generated 336/1024
|
| 24 |
+
[decode] steps128_c1024_t1p45 generated 352/1024
|
| 25 |
+
[decode] steps128_c1024_t1p45 generated 368/1024
|
| 26 |
+
[decode] steps128_c1024_t1p45 generated 384/1024
|
| 27 |
+
[decode] steps128_c1024_t1p45 generated 400/1024
|
| 28 |
+
[decode] steps128_c1024_t1p45 generated 416/1024
|
| 29 |
+
[decode] steps128_c1024_t1p45 generated 432/1024
|
| 30 |
+
[decode] steps128_c1024_t1p45 generated 448/1024
|
| 31 |
+
[decode] steps128_c1024_t1p45 generated 464/1024
|
| 32 |
+
[decode] steps128_c1024_t1p45 generated 480/1024
|
| 33 |
+
[decode] steps128_c1024_t1p45 generated 496/1024
|
| 34 |
+
[decode] steps128_c1024_t1p45 generated 512/1024
|
| 35 |
+
[decode] steps128_c1024_t1p45 generated 528/1024
|
| 36 |
+
[decode] steps128_c1024_t1p45 generated 544/1024
|
| 37 |
+
[decode] steps128_c1024_t1p45 generated 560/1024
|
| 38 |
+
[decode] steps128_c1024_t1p45 generated 576/1024
|
| 39 |
+
[decode] steps128_c1024_t1p45 generated 592/1024
|
| 40 |
+
[decode] steps128_c1024_t1p45 generated 608/1024
|
| 41 |
+
[decode] steps128_c1024_t1p45 generated 624/1024
|
| 42 |
+
[decode] steps128_c1024_t1p45 generated 640/1024
|
| 43 |
+
[decode] steps128_c1024_t1p45 generated 656/1024
|
| 44 |
+
[decode] steps128_c1024_t1p45 generated 672/1024
|
| 45 |
+
[decode] steps128_c1024_t1p45 generated 688/1024
|
| 46 |
+
[decode] steps128_c1024_t1p45 generated 704/1024
|
| 47 |
+
[decode] steps128_c1024_t1p45 generated 720/1024
|
| 48 |
+
[decode] steps128_c1024_t1p45 generated 736/1024
|
| 49 |
+
[decode] steps128_c1024_t1p45 generated 752/1024
|
| 50 |
+
[decode] steps128_c1024_t1p45 generated 768/1024
|
| 51 |
+
[decode] steps128_c1024_t1p45 generated 784/1024
|
| 52 |
+
[decode] steps128_c1024_t1p45 generated 800/1024
|
| 53 |
+
[decode] steps128_c1024_t1p45 generated 816/1024
|
| 54 |
+
[decode] steps128_c1024_t1p45 generated 832/1024
|
| 55 |
+
[decode] steps128_c1024_t1p45 generated 848/1024
|
| 56 |
+
[decode] steps128_c1024_t1p45 generated 864/1024
|
| 57 |
+
[decode] steps128_c1024_t1p45 generated 880/1024
|
| 58 |
+
[decode] steps128_c1024_t1p45 generated 896/1024
|
| 59 |
+
[decode] steps128_c1024_t1p45 generated 912/1024
|
| 60 |
+
[decode] steps128_c1024_t1p45 generated 928/1024
|
| 61 |
+
[decode] steps128_c1024_t1p45 generated 944/1024
|
| 62 |
+
[decode] steps128_c1024_t1p45 generated 960/1024
|
| 63 |
+
[decode] steps128_c1024_t1p45 generated 976/1024
|
| 64 |
+
[decode] steps128_c1024_t1p45 generated 992/1024
|
| 65 |
+
[decode] steps128_c1024_t1p45 generated 1008/1024
|
| 66 |
+
[decode] steps128_c1024_t1p45 generated 1024/1024
|
| 67 |
+
[summary] {"name": "steps128_c1024_t1p45", "step": 10000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 30.13676440721101, "stripped_genppl": 34.68526771869161, "sample_entropy": 3.465849114229341, "distinct_1": 0.03083038330078125, "distinct_2": 0.19376691683070865, "top_token_mass": 0.232330322265625, "raw_kept": 1024, "stripped_kept": 1024}
|
| 68 |
+
[watch-classic] 2026-05-21_01:46:53 done step_0010000
|
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0020000_t1p45.log
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[watch-classic] 2026-05-21_04:16:54 infer runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000.pt -> docs/lta_samples/metrics_20260520/lm1b_classic_repro_every10k_normal_steps_state_t1p45_c1024_n1024/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000
|
| 2 |
+
[ckpt] runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000.pt step=20000
|
| 3 |
+
[decode] steps128_c1024_t1p45 generated 16/1024
|
| 4 |
+
[decode] steps128_c1024_t1p45 generated 32/1024
|
| 5 |
+
[decode] steps128_c1024_t1p45 generated 48/1024
|
| 6 |
+
[decode] steps128_c1024_t1p45 generated 64/1024
|
| 7 |
+
[decode] steps128_c1024_t1p45 generated 80/1024
|
| 8 |
+
[decode] steps128_c1024_t1p45 generated 96/1024
|
| 9 |
+
[decode] steps128_c1024_t1p45 generated 112/1024
|
| 10 |
+
[decode] steps128_c1024_t1p45 generated 128/1024
|
| 11 |
+
[decode] steps128_c1024_t1p45 generated 144/1024
|
| 12 |
+
[decode] steps128_c1024_t1p45 generated 160/1024
|
| 13 |
+
[decode] steps128_c1024_t1p45 generated 176/1024
|
| 14 |
+
[decode] steps128_c1024_t1p45 generated 192/1024
|
| 15 |
+
[decode] steps128_c1024_t1p45 generated 208/1024
|
| 16 |
+
[decode] steps128_c1024_t1p45 generated 224/1024
|
| 17 |
+
[decode] steps128_c1024_t1p45 generated 240/1024
|
| 18 |
+
[decode] steps128_c1024_t1p45 generated 256/1024
|
| 19 |
+
[decode] steps128_c1024_t1p45 generated 272/1024
|
| 20 |
+
[decode] steps128_c1024_t1p45 generated 288/1024
|
| 21 |
+
[decode] steps128_c1024_t1p45 generated 304/1024
|
| 22 |
+
[decode] steps128_c1024_t1p45 generated 320/1024
|
| 23 |
+
[decode] steps128_c1024_t1p45 generated 336/1024
|
| 24 |
+
[decode] steps128_c1024_t1p45 generated 352/1024
|
| 25 |
+
[decode] steps128_c1024_t1p45 generated 368/1024
|
| 26 |
+
[decode] steps128_c1024_t1p45 generated 384/1024
|
| 27 |
+
[decode] steps128_c1024_t1p45 generated 400/1024
|
| 28 |
+
[decode] steps128_c1024_t1p45 generated 416/1024
|
| 29 |
+
[decode] steps128_c1024_t1p45 generated 432/1024
|
| 30 |
+
[decode] steps128_c1024_t1p45 generated 448/1024
|
| 31 |
+
[decode] steps128_c1024_t1p45 generated 464/1024
|
| 32 |
+
[decode] steps128_c1024_t1p45 generated 480/1024
|
| 33 |
+
[decode] steps128_c1024_t1p45 generated 496/1024
|
| 34 |
+
[decode] steps128_c1024_t1p45 generated 512/1024
|
| 35 |
+
[decode] steps128_c1024_t1p45 generated 528/1024
|
| 36 |
+
[decode] steps128_c1024_t1p45 generated 544/1024
|
| 37 |
+
[decode] steps128_c1024_t1p45 generated 560/1024
|
| 38 |
+
[decode] steps128_c1024_t1p45 generated 576/1024
|
| 39 |
+
[decode] steps128_c1024_t1p45 generated 592/1024
|
| 40 |
+
[decode] steps128_c1024_t1p45 generated 608/1024
|
| 41 |
+
[decode] steps128_c1024_t1p45 generated 624/1024
|
| 42 |
+
[decode] steps128_c1024_t1p45 generated 640/1024
|
| 43 |
+
[decode] steps128_c1024_t1p45 generated 656/1024
|
| 44 |
+
[decode] steps128_c1024_t1p45 generated 672/1024
|
| 45 |
+
[decode] steps128_c1024_t1p45 generated 688/1024
|
| 46 |
+
[decode] steps128_c1024_t1p45 generated 704/1024
|
| 47 |
+
[decode] steps128_c1024_t1p45 generated 720/1024
|
| 48 |
+
[decode] steps128_c1024_t1p45 generated 736/1024
|
| 49 |
+
[decode] steps128_c1024_t1p45 generated 752/1024
|
| 50 |
+
[decode] steps128_c1024_t1p45 generated 768/1024
|
| 51 |
+
[decode] steps128_c1024_t1p45 generated 784/1024
|
| 52 |
+
[decode] steps128_c1024_t1p45 generated 800/1024
|
| 53 |
+
[decode] steps128_c1024_t1p45 generated 816/1024
|
| 54 |
+
[decode] steps128_c1024_t1p45 generated 832/1024
|
| 55 |
+
[decode] steps128_c1024_t1p45 generated 848/1024
|
| 56 |
+
[decode] steps128_c1024_t1p45 generated 864/1024
|
| 57 |
+
[decode] steps128_c1024_t1p45 generated 880/1024
|
| 58 |
+
[decode] steps128_c1024_t1p45 generated 896/1024
|
| 59 |
+
[decode] steps128_c1024_t1p45 generated 912/1024
|
| 60 |
+
[decode] steps128_c1024_t1p45 generated 928/1024
|
| 61 |
+
[decode] steps128_c1024_t1p45 generated 944/1024
|
| 62 |
+
[decode] steps128_c1024_t1p45 generated 960/1024
|
| 63 |
+
[decode] steps128_c1024_t1p45 generated 976/1024
|
| 64 |
+
[decode] steps128_c1024_t1p45 generated 992/1024
|
| 65 |
+
[decode] steps128_c1024_t1p45 generated 1008/1024
|
| 66 |
+
[decode] steps128_c1024_t1p45 generated 1024/1024
|
| 67 |
+
[summary] {"name": "steps128_c1024_t1p45", "step": 20000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 46.858096679153796, "stripped_genppl": 50.24137870832906, "sample_entropy": 3.638899175986386, "distinct_1": 0.0443572998046875, "distinct_2": 0.28157295767716534, "top_token_mass": 0.1688232421875, "raw_kept": 1024, "stripped_kept": 1024}
|
| 68 |
+
[watch-classic] 2026-05-21_04:26:05 done step_0020000
|
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0030000_t1p45.log
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[watch-classic] 2026-05-21_06:55:06 infer runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000.pt -> docs/lta_samples/metrics_20260520/lm1b_classic_repro_every10k_normal_steps_state_t1p45_c1024_n1024/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000
|
| 2 |
+
[ckpt] runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000.pt step=30000
|
| 3 |
+
[decode] steps128_c1024_t1p45 generated 16/1024
|
| 4 |
+
[decode] steps128_c1024_t1p45 generated 32/1024
|
| 5 |
+
[decode] steps128_c1024_t1p45 generated 48/1024
|
| 6 |
+
[decode] steps128_c1024_t1p45 generated 64/1024
|
| 7 |
+
[decode] steps128_c1024_t1p45 generated 80/1024
|
| 8 |
+
[decode] steps128_c1024_t1p45 generated 96/1024
|
| 9 |
+
[decode] steps128_c1024_t1p45 generated 112/1024
|
| 10 |
+
[decode] steps128_c1024_t1p45 generated 128/1024
|
| 11 |
+
[decode] steps128_c1024_t1p45 generated 144/1024
|
| 12 |
+
[decode] steps128_c1024_t1p45 generated 160/1024
|
| 13 |
+
[decode] steps128_c1024_t1p45 generated 176/1024
|
| 14 |
+
[decode] steps128_c1024_t1p45 generated 192/1024
|
| 15 |
+
[decode] steps128_c1024_t1p45 generated 208/1024
|
| 16 |
+
[decode] steps128_c1024_t1p45 generated 224/1024
|
| 17 |
+
[decode] steps128_c1024_t1p45 generated 240/1024
|
| 18 |
+
[decode] steps128_c1024_t1p45 generated 256/1024
|
| 19 |
+
[decode] steps128_c1024_t1p45 generated 272/1024
|
| 20 |
+
[decode] steps128_c1024_t1p45 generated 288/1024
|
| 21 |
+
[decode] steps128_c1024_t1p45 generated 304/1024
|
| 22 |
+
[decode] steps128_c1024_t1p45 generated 320/1024
|
| 23 |
+
[decode] steps128_c1024_t1p45 generated 336/1024
|
| 24 |
+
[decode] steps128_c1024_t1p45 generated 352/1024
|
| 25 |
+
[decode] steps128_c1024_t1p45 generated 368/1024
|
| 26 |
+
[decode] steps128_c1024_t1p45 generated 384/1024
|
| 27 |
+
[decode] steps128_c1024_t1p45 generated 400/1024
|
| 28 |
+
[decode] steps128_c1024_t1p45 generated 416/1024
|
| 29 |
+
[decode] steps128_c1024_t1p45 generated 432/1024
|
| 30 |
+
[decode] steps128_c1024_t1p45 generated 448/1024
|
| 31 |
+
[decode] steps128_c1024_t1p45 generated 464/1024
|
| 32 |
+
[decode] steps128_c1024_t1p45 generated 480/1024
|
| 33 |
+
[decode] steps128_c1024_t1p45 generated 496/1024
|
| 34 |
+
[decode] steps128_c1024_t1p45 generated 512/1024
|
| 35 |
+
[decode] steps128_c1024_t1p45 generated 528/1024
|
| 36 |
+
[decode] steps128_c1024_t1p45 generated 544/1024
|
| 37 |
+
[decode] steps128_c1024_t1p45 generated 560/1024
|
| 38 |
+
[decode] steps128_c1024_t1p45 generated 576/1024
|
| 39 |
+
[decode] steps128_c1024_t1p45 generated 592/1024
|
| 40 |
+
[decode] steps128_c1024_t1p45 generated 608/1024
|
| 41 |
+
[decode] steps128_c1024_t1p45 generated 624/1024
|
| 42 |
+
[decode] steps128_c1024_t1p45 generated 640/1024
|
| 43 |
+
[decode] steps128_c1024_t1p45 generated 656/1024
|
| 44 |
+
[decode] steps128_c1024_t1p45 generated 672/1024
|
| 45 |
+
[decode] steps128_c1024_t1p45 generated 688/1024
|
| 46 |
+
[decode] steps128_c1024_t1p45 generated 704/1024
|
| 47 |
+
[decode] steps128_c1024_t1p45 generated 720/1024
|
| 48 |
+
[decode] steps128_c1024_t1p45 generated 736/1024
|
| 49 |
+
[decode] steps128_c1024_t1p45 generated 752/1024
|
| 50 |
+
[decode] steps128_c1024_t1p45 generated 768/1024
|
| 51 |
+
[decode] steps128_c1024_t1p45 generated 784/1024
|
| 52 |
+
[decode] steps128_c1024_t1p45 generated 800/1024
|
| 53 |
+
[decode] steps128_c1024_t1p45 generated 816/1024
|
| 54 |
+
[decode] steps128_c1024_t1p45 generated 832/1024
|
| 55 |
+
[decode] steps128_c1024_t1p45 generated 848/1024
|
| 56 |
+
[decode] steps128_c1024_t1p45 generated 864/1024
|
| 57 |
+
[decode] steps128_c1024_t1p45 generated 880/1024
|
| 58 |
+
[decode] steps128_c1024_t1p45 generated 896/1024
|
| 59 |
+
[decode] steps128_c1024_t1p45 generated 912/1024
|
| 60 |
+
[decode] steps128_c1024_t1p45 generated 928/1024
|
| 61 |
+
[decode] steps128_c1024_t1p45 generated 944/1024
|
| 62 |
+
[decode] steps128_c1024_t1p45 generated 960/1024
|
| 63 |
+
[decode] steps128_c1024_t1p45 generated 976/1024
|
| 64 |
+
[decode] steps128_c1024_t1p45 generated 992/1024
|
| 65 |
+
[decode] steps128_c1024_t1p45 generated 1008/1024
|
| 66 |
+
[decode] steps128_c1024_t1p45 generated 1024/1024
|
| 67 |
+
[summary] {"name": "steps128_c1024_t1p45", "step": 30000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 42.11245193528327, "stripped_genppl": 57.07543476214156, "sample_entropy": 4.026958025870078, "distinct_1": 0.04402923583984375, "distinct_2": 0.3001584030511811, "top_token_mass": 0.0693817138671875, "raw_kept": 1024, "stripped_kept": 1024}
|
| 68 |
+
[watch-classic] 2026-05-21_07:04:13 done step_0030000
|
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/processed_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_steps128_c1024_t1p45_n1024.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000.pt
|
| 2 |
+
runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000.pt
|
| 3 |
+
runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000.pt
|
| 4 |
+
runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0040000.pt
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.bat
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@REM Copyright (c) 2020-202x The virtualenv developers
|
| 2 |
+
@REM
|
| 3 |
+
@REM Permission is hereby granted, free of charge, to any person obtaining
|
| 4 |
+
@REM a copy of this software and associated documentation files (the
|
| 5 |
+
@REM "Software"), to deal in the Software without restriction, including
|
| 6 |
+
@REM without limitation the rights to use, copy, modify, merge, publish,
|
| 7 |
+
@REM distribute, sublicense, and/or sell copies of the Software, and to
|
| 8 |
+
@REM permit persons to whom the Software is furnished to do so, subject to
|
| 9 |
+
@REM the following conditions:
|
| 10 |
+
@REM
|
| 11 |
+
@REM The above copyright notice and this permission notice shall be
|
| 12 |
+
@REM included in all copies or substantial portions of the Software.
|
| 13 |
+
@REM
|
| 14 |
+
@REM THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 15 |
+
@REM EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 16 |
+
@REM MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
| 17 |
+
@REM NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
| 18 |
+
@REM LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
| 19 |
+
@REM OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
| 20 |
+
@REM WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 21 |
+
|
| 22 |
+
@REM This file is UTF-8 encoded, so we need to update the current code page while executing it
|
| 23 |
+
@for /f "tokens=2 delims=:." %%a in ('"%SystemRoot%\System32\chcp.com"') do @set _OLD_CODEPAGE=%%a
|
| 24 |
+
|
| 25 |
+
@if defined _OLD_CODEPAGE (
|
| 26 |
+
"%SystemRoot%\System32\chcp.com" 65001 > nul
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
@for %%i in ("/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv") do @set "VIRTUAL_ENV=%%~fi"
|
| 30 |
+
|
| 31 |
+
@set "VIRTUAL_ENV_PROMPT="
|
| 32 |
+
@if NOT DEFINED VIRTUAL_ENV_PROMPT (
|
| 33 |
+
@for %%d in ("%VIRTUAL_ENV%") do @set "VIRTUAL_ENV_PROMPT=%%~nxd"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
@if defined _OLD_VIRTUAL_PROMPT (
|
| 37 |
+
@set "PROMPT=%_OLD_VIRTUAL_PROMPT%"
|
| 38 |
+
) else (
|
| 39 |
+
@if not defined PROMPT (
|
| 40 |
+
@set "PROMPT=$P$G"
|
| 41 |
+
)
|
| 42 |
+
@if not defined VIRTUAL_ENV_DISABLE_PROMPT (
|
| 43 |
+
@set "_OLD_VIRTUAL_PROMPT=%PROMPT%"
|
| 44 |
+
)
|
| 45 |
+
)
|
| 46 |
+
@if not defined VIRTUAL_ENV_DISABLE_PROMPT (
|
| 47 |
+
@set "PROMPT=(%VIRTUAL_ENV_PROMPT%) %PROMPT%"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
@REM Don't use () to avoid problems with them in %PATH%
|
| 51 |
+
@if defined _OLD_VIRTUAL_PYTHONHOME @goto ENDIFVHOME
|
| 52 |
+
@set "_OLD_VIRTUAL_PYTHONHOME=%PYTHONHOME%"
|
| 53 |
+
:ENDIFVHOME
|
| 54 |
+
|
| 55 |
+
@set PYTHONHOME=
|
| 56 |
+
|
| 57 |
+
@REM if defined _OLD_VIRTUAL_PATH (
|
| 58 |
+
@if not defined _OLD_VIRTUAL_PATH @goto ENDIFVPATH1
|
| 59 |
+
@set "PATH=%_OLD_VIRTUAL_PATH%"
|
| 60 |
+
:ENDIFVPATH1
|
| 61 |
+
@REM ) else (
|
| 62 |
+
@if defined _OLD_VIRTUAL_PATH @goto ENDIFVPATH2
|
| 63 |
+
@set "_OLD_VIRTUAL_PATH=%PATH%"
|
| 64 |
+
:ENDIFVPATH2
|
| 65 |
+
|
| 66 |
+
@set "PATH=%VIRTUAL_ENV%\bin;%PATH%"
|
| 67 |
+
|
| 68 |
+
@if defined _OLD_CODEPAGE (
|
| 69 |
+
"%SystemRoot%\System32\chcp.com" %_OLD_CODEPAGE% > nul
|
| 70 |
+
@set _OLD_CODEPAGE=
|
| 71 |
+
)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.fish
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020-202x The virtualenv developers
|
| 2 |
+
#
|
| 3 |
+
# Permission is hereby granted, free of charge, to any person obtaining
|
| 4 |
+
# a copy of this software and associated documentation files (the
|
| 5 |
+
# "Software"), to deal in the Software without restriction, including
|
| 6 |
+
# without limitation the rights to use, copy, modify, merge, publish,
|
| 7 |
+
# distribute, sublicense, and/or sell copies of the Software, and to
|
| 8 |
+
# permit persons to whom the Software is furnished to do so, subject to
|
| 9 |
+
# the following conditions:
|
| 10 |
+
#
|
| 11 |
+
# The above copyright notice and this permission notice shall be
|
| 12 |
+
# included in all copies or substantial portions of the Software.
|
| 13 |
+
#
|
| 14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 15 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 16 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
| 17 |
+
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
| 18 |
+
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
| 19 |
+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
| 20 |
+
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 21 |
+
|
| 22 |
+
# This file must be used using `source bin/activate.fish` *within a running fish ( http://fishshell.com ) session*.
|
| 23 |
+
# Do not run it directly.
|
| 24 |
+
|
| 25 |
+
function _bashify_path -d "Converts a fish path to something bash can recognize"
|
| 26 |
+
set fishy_path $argv
|
| 27 |
+
set bashy_path $fishy_path[1]
|
| 28 |
+
for path_part in $fishy_path[2..-1]
|
| 29 |
+
set bashy_path "$bashy_path:$path_part"
|
| 30 |
+
end
|
| 31 |
+
echo $bashy_path
|
| 32 |
+
end
|
| 33 |
+
|
| 34 |
+
function _fishify_path -d "Converts a bash path to something fish can recognize"
|
| 35 |
+
echo $argv | tr ':' '\n'
|
| 36 |
+
end
|
| 37 |
+
|
| 38 |
+
function deactivate -d 'Exit virtualenv mode and return to the normal environment.'
|
| 39 |
+
# reset old environment variables
|
| 40 |
+
if test -n "$_OLD_VIRTUAL_PATH"
|
| 41 |
+
# https://github.com/fish-shell/fish-shell/issues/436 altered PATH handling
|
| 42 |
+
if test (string sub -s 1 -l 1 $FISH_VERSION) -lt 3
|
| 43 |
+
set -gx PATH (_fishify_path "$_OLD_VIRTUAL_PATH")
|
| 44 |
+
else
|
| 45 |
+
set -gx PATH $_OLD_VIRTUAL_PATH
|
| 46 |
+
end
|
| 47 |
+
set -e _OLD_VIRTUAL_PATH
|
| 48 |
+
end
|
| 49 |
+
|
| 50 |
+
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
|
| 51 |
+
set -gx PYTHONHOME "$_OLD_VIRTUAL_PYTHONHOME"
|
| 52 |
+
set -e _OLD_VIRTUAL_PYTHONHOME
|
| 53 |
+
end
|
| 54 |
+
|
| 55 |
+
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
|
| 56 |
+
and functions -q _old_fish_prompt
|
| 57 |
+
# Set an empty local `$fish_function_path` to allow the removal of `fish_prompt` using `functions -e`.
|
| 58 |
+
set -l fish_function_path
|
| 59 |
+
|
| 60 |
+
# Erase virtualenv's `fish_prompt` and restore the original.
|
| 61 |
+
functions -e fish_prompt
|
| 62 |
+
functions -c _old_fish_prompt fish_prompt
|
| 63 |
+
functions -e _old_fish_prompt
|
| 64 |
+
set -e _OLD_FISH_PROMPT_OVERRIDE
|
| 65 |
+
end
|
| 66 |
+
|
| 67 |
+
set -e VIRTUAL_ENV
|
| 68 |
+
set -e VIRTUAL_ENV_PROMPT
|
| 69 |
+
|
| 70 |
+
if test "$argv[1]" != 'nondestructive'
|
| 71 |
+
# Self-destruct!
|
| 72 |
+
functions -e pydoc
|
| 73 |
+
functions -e deactivate
|
| 74 |
+
functions -e _bashify_path
|
| 75 |
+
functions -e _fishify_path
|
| 76 |
+
end
|
| 77 |
+
end
|
| 78 |
+
|
| 79 |
+
# Unset irrelevant variables.
|
| 80 |
+
deactivate nondestructive
|
| 81 |
+
|
| 82 |
+
set -gx VIRTUAL_ENV '/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv'
|
| 83 |
+
|
| 84 |
+
# https://github.com/fish-shell/fish-shell/issues/436 altered PATH handling
|
| 85 |
+
if test (string sub -s 1 -l 1 $FISH_VERSION) -lt 3
|
| 86 |
+
set -gx _OLD_VIRTUAL_PATH (_bashify_path $PATH)
|
| 87 |
+
else
|
| 88 |
+
set -gx _OLD_VIRTUAL_PATH $PATH
|
| 89 |
+
end
|
| 90 |
+
set -gx PATH "$VIRTUAL_ENV"'/bin' $PATH
|
| 91 |
+
|
| 92 |
+
# Prompt override provided?
|
| 93 |
+
# If not, just use the environment name.
|
| 94 |
+
if test -n ''
|
| 95 |
+
set -gx VIRTUAL_ENV_PROMPT ''
|
| 96 |
+
else
|
| 97 |
+
set -gx VIRTUAL_ENV_PROMPT (basename "$VIRTUAL_ENV")
|
| 98 |
+
end
|
| 99 |
+
|
| 100 |
+
# Unset `$PYTHONHOME` if set.
|
| 101 |
+
if set -q PYTHONHOME
|
| 102 |
+
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
|
| 103 |
+
set -e PYTHONHOME
|
| 104 |
+
end
|
| 105 |
+
|
| 106 |
+
function pydoc
|
| 107 |
+
python -m pydoc $argv
|
| 108 |
+
end
|
| 109 |
+
|
| 110 |
+
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
|
| 111 |
+
# Copy the current `fish_prompt` function as `_old_fish_prompt`.
|
| 112 |
+
functions -c fish_prompt _old_fish_prompt
|
| 113 |
+
|
| 114 |
+
function fish_prompt
|
| 115 |
+
# Run the user's prompt first; it might depend on (pipe)status.
|
| 116 |
+
set -l prompt (_old_fish_prompt)
|
| 117 |
+
|
| 118 |
+
printf '(%s) ' $VIRTUAL_ENV_PROMPT
|
| 119 |
+
|
| 120 |
+
string join -- \n $prompt # handle multi-line prompts
|
| 121 |
+
end
|
| 122 |
+
|
| 123 |
+
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
|
| 124 |
+
end
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/f2py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import sys
|
| 4 |
+
from numpy.f2py.f2py2e import main
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
if sys.argv[0].endswith("-script.pyw"):
|
| 7 |
+
sys.argv[0] = sys.argv[0][:-11]
|
| 8 |
+
elif sys.argv[0].endswith(".exe"):
|
| 9 |
+
sys.argv[0] = sys.argv[0][:-4]
|
| 10 |
+
sys.exit(main())
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/pydoc.bat
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@REM Copyright (c) 2020-202x The virtualenv developers
|
| 2 |
+
@REM
|
| 3 |
+
@REM Permission is hereby granted, free of charge, to any person obtaining
|
| 4 |
+
@REM a copy of this software and associated documentation files (the
|
| 5 |
+
@REM "Software"), to deal in the Software without restriction, including
|
| 6 |
+
@REM without limitation the rights to use, copy, modify, merge, publish,
|
| 7 |
+
@REM distribute, sublicense, and/or sell copies of the Software, and to
|
| 8 |
+
@REM permit persons to whom the Software is furnished to do so, subject to
|
| 9 |
+
@REM the following conditions:
|
| 10 |
+
@REM
|
| 11 |
+
@REM The above copyright notice and this permission notice shall be
|
| 12 |
+
@REM included in all copies or substantial portions of the Software.
|
| 13 |
+
@REM
|
| 14 |
+
@REM THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 15 |
+
@REM EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 16 |
+
@REM MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
| 17 |
+
@REM NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
| 18 |
+
@REM LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
| 19 |
+
@REM OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
| 20 |
+
@REM WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 21 |
+
|
| 22 |
+
python.exe -m pydoc %*
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/audio_utils.py
ADDED
|
@@ -0,0 +1,1254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
|
| 16 |
+
and remove unnecessary dependencies.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import base64
|
| 20 |
+
import importlib
|
| 21 |
+
import io
|
| 22 |
+
import os
|
| 23 |
+
import warnings
|
| 24 |
+
from collections.abc import Sequence
|
| 25 |
+
from io import BytesIO
|
| 26 |
+
from typing import TYPE_CHECKING, Any, Union
|
| 27 |
+
|
| 28 |
+
import httpx
|
| 29 |
+
import numpy as np
|
| 30 |
+
from packaging import version
|
| 31 |
+
|
| 32 |
+
from .utils import (
|
| 33 |
+
is_librosa_available,
|
| 34 |
+
is_numpy_array,
|
| 35 |
+
is_soundfile_available,
|
| 36 |
+
is_torch_tensor,
|
| 37 |
+
is_torchcodec_available,
|
| 38 |
+
requires_backends,
|
| 39 |
+
)
|
| 40 |
+
from .utils.generic import retry
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if TYPE_CHECKING:
|
| 44 |
+
import torch
|
| 45 |
+
|
| 46 |
+
if is_soundfile_available():
|
| 47 |
+
import soundfile as sf
|
| 48 |
+
|
| 49 |
+
if is_librosa_available():
|
| 50 |
+
import librosa
|
| 51 |
+
|
| 52 |
+
# TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
|
| 53 |
+
import soxr
|
| 54 |
+
|
| 55 |
+
if is_torchcodec_available():
|
| 56 |
+
TORCHCODEC_VERSION = version.parse(importlib.metadata.version("torchcodec"))
|
| 57 |
+
|
| 58 |
+
AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@retry(exceptions=(httpx.HTTPError,))
|
| 62 |
+
def _fetch_audio_bytes(url: str, timeout: float | None = 10.0) -> bytes:
|
| 63 |
+
"""Fetch audio bytes from a URL with automatic retry and exponential backoff."""
|
| 64 |
+
response = httpx.get(url, follow_redirects=True, timeout=timeout)
|
| 65 |
+
response.raise_for_status()
|
| 66 |
+
return response.content
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_audio(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
|
| 70 |
+
"""
|
| 71 |
+
Loads `audio` to an np.ndarray object.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
audio (`str` or `np.ndarray`):
|
| 75 |
+
The audio to be loaded to the numpy array format.
|
| 76 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 77 |
+
The sampling rate to be used when loading the audio. It should be same as the
|
| 78 |
+
sampling rate the model you will be using further was trained with.
|
| 79 |
+
timeout (`float`, *optional*):
|
| 80 |
+
The timeout value in seconds for the URL request.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
`np.ndarray`: A numpy array representing the audio.
|
| 84 |
+
"""
|
| 85 |
+
if isinstance(audio, str):
|
| 86 |
+
# Try to load with `torchcodec` but do not enforce users to install it. If not found
|
| 87 |
+
# fallback to `librosa`. If using an audio-only model, most probably `torchcodec` won't be
|
| 88 |
+
# needed. Do not raise any errors if not installed or versions do not match
|
| 89 |
+
if is_torchcodec_available() and version.parse("0.3.0") <= TORCHCODEC_VERSION:
|
| 90 |
+
audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate, timeout=timeout)
|
| 91 |
+
elif audio.rsplit("?", 1)[0].lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv")):
|
| 92 |
+
raise RuntimeError(
|
| 93 |
+
f"The audio source appears to be a video file ('{audio.split('/')[-1]}'). "
|
| 94 |
+
"librosa cannot decode video containers. "
|
| 95 |
+
"Install torchcodec>=0.3.0 (`pip install torchcodec`) to load audio from video files."
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
|
| 99 |
+
elif not isinstance(audio, np.ndarray):
|
| 100 |
+
raise TypeError(
|
| 101 |
+
"Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
|
| 102 |
+
)
|
| 103 |
+
return audio
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def load_audio_torchcodec(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
|
| 107 |
+
"""
|
| 108 |
+
Loads `audio` to an np.ndarray object using `torchcodec`.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
audio (`str` or `np.ndarray`):
|
| 112 |
+
The audio to be loaded to the numpy array format.
|
| 113 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 114 |
+
The sampling rate to be used when loading the audio. It should be same as the
|
| 115 |
+
sampling rate the model you will be using further was trained with.
|
| 116 |
+
timeout (`float`, *optional*):
|
| 117 |
+
The timeout value in seconds for the URL request.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
`np.ndarray`: A numpy array representing the audio.
|
| 121 |
+
"""
|
| 122 |
+
# Lazy import so that issues in torchcodec compatibility don't crash the whole library
|
| 123 |
+
requires_backends(load_audio_torchcodec, ["torchcodec"])
|
| 124 |
+
from torchcodec.decoders import AudioDecoder
|
| 125 |
+
|
| 126 |
+
# Fetch bytes for URLs so we get retry logic; torchcodec does not surface ffmpeg network retries options
|
| 127 |
+
if isinstance(audio, str) and audio.startswith(("http://", "https://")):
|
| 128 |
+
audio = _fetch_audio_bytes(audio, timeout=timeout)
|
| 129 |
+
|
| 130 |
+
# Set `num_channels` to `1` which is what most models expects and the default in librosa
|
| 131 |
+
decoder = AudioDecoder(audio, sample_rate=sampling_rate, num_channels=1)
|
| 132 |
+
audio = decoder.get_all_samples().data[0].numpy() # NOTE: feature extractors don't accept torch tensors
|
| 133 |
+
return audio
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_audio_librosa(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
|
| 137 |
+
"""
|
| 138 |
+
Loads `audio` to an np.ndarray object using `librosa`.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
audio (`str` or `np.ndarray`):
|
| 142 |
+
The audio to be loaded to the numpy array format.
|
| 143 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 144 |
+
The sampling rate to be used when loading the audio. It should be same as the
|
| 145 |
+
sampling rate the model you will be using further was trained with.
|
| 146 |
+
timeout (`float`, *optional*):
|
| 147 |
+
The timeout value in seconds for the URL request.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
`np.ndarray`: A numpy array representing the audio.
|
| 151 |
+
"""
|
| 152 |
+
requires_backends(load_audio_librosa, ["librosa"])
|
| 153 |
+
|
| 154 |
+
# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
|
| 155 |
+
if audio.startswith("http://") or audio.startswith("https://"):
|
| 156 |
+
audio = librosa.load(BytesIO(_fetch_audio_bytes(audio, timeout=timeout)), sr=sampling_rate)[0]
|
| 157 |
+
elif os.path.isfile(audio):
|
| 158 |
+
audio = librosa.load(audio, sr=sampling_rate)[0]
|
| 159 |
+
return audio
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def load_audio_as(
|
| 163 |
+
audio: str,
|
| 164 |
+
return_format: str,
|
| 165 |
+
timeout: int | None = None,
|
| 166 |
+
force_mono: bool = False,
|
| 167 |
+
sampling_rate: int | None = None,
|
| 168 |
+
) -> str | dict[str, Any] | io.BytesIO | None:
|
| 169 |
+
"""
|
| 170 |
+
Load audio from either a local file path or URL and return in specified format.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
audio (`str`): Either a local file path or a URL to an audio file
|
| 174 |
+
return_format (`str`): Format to return the audio in:
|
| 175 |
+
- "base64": Base64 encoded string
|
| 176 |
+
- "dict": Dictionary with data and format
|
| 177 |
+
- "buffer": BytesIO object
|
| 178 |
+
timeout (`int`, *optional*): Timeout for URL requests in seconds
|
| 179 |
+
force_mono (`bool`): Whether to convert stereo audio to mono
|
| 180 |
+
sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
`Union[str, Dict[str, Any], io.BytesIO, None]`:
|
| 184 |
+
- `str`: Base64 encoded audio data (if return_format="base64")
|
| 185 |
+
- `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict")
|
| 186 |
+
- `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer")
|
| 187 |
+
"""
|
| 188 |
+
requires_backends(load_audio_as, ["librosa"])
|
| 189 |
+
|
| 190 |
+
if return_format not in ["base64", "dict", "buffer"]:
|
| 191 |
+
raise ValueError(f"Invalid return_format: {return_format}. Must be 'base64', 'dict', or 'buffer'")
|
| 192 |
+
|
| 193 |
+
try:
|
| 194 |
+
# Load audio bytes from URL or file
|
| 195 |
+
audio_bytes = None
|
| 196 |
+
if audio.startswith(("http://", "https://")):
|
| 197 |
+
audio_bytes = _fetch_audio_bytes(audio, timeout=timeout)
|
| 198 |
+
elif os.path.isfile(audio):
|
| 199 |
+
with open(audio, "rb") as audio_file:
|
| 200 |
+
audio_bytes = audio_file.read()
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError(f"File not found: {audio}")
|
| 203 |
+
|
| 204 |
+
# Process audio data
|
| 205 |
+
with io.BytesIO(audio_bytes) as audio_file:
|
| 206 |
+
with sf.SoundFile(audio_file) as f:
|
| 207 |
+
audio_array = f.read(dtype="float32")
|
| 208 |
+
original_sr = f.samplerate
|
| 209 |
+
audio_format = f.format
|
| 210 |
+
if sampling_rate is not None and sampling_rate != original_sr:
|
| 211 |
+
# Resample audio to target sampling rate
|
| 212 |
+
audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ")
|
| 213 |
+
else:
|
| 214 |
+
sampling_rate = original_sr
|
| 215 |
+
|
| 216 |
+
# Convert to mono if needed
|
| 217 |
+
if force_mono and audio_array.ndim != 1:
|
| 218 |
+
audio_array = audio_array.mean(axis=1)
|
| 219 |
+
|
| 220 |
+
buffer = io.BytesIO()
|
| 221 |
+
sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper())
|
| 222 |
+
buffer.seek(0)
|
| 223 |
+
|
| 224 |
+
if return_format == "buffer":
|
| 225 |
+
return buffer
|
| 226 |
+
elif return_format == "base64":
|
| 227 |
+
return base64.b64encode(buffer.read()).decode("utf-8")
|
| 228 |
+
elif return_format == "dict":
|
| 229 |
+
return {
|
| 230 |
+
"data": base64.b64encode(buffer.read()).decode("utf-8"),
|
| 231 |
+
"format": audio_format.lower(),
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
except Exception as e:
|
| 235 |
+
raise ValueError(f"Error loading audio: {e}")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def conv1d_output_length(module: "torch.nn.Conv1d", input_length: int) -> int:
|
| 239 |
+
"""
|
| 240 |
+
Computes the output length of a 1D convolution layer according to torch's documentation:
|
| 241 |
+
https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 242 |
+
"""
|
| 243 |
+
return int(
|
| 244 |
+
(input_length + 2 * module.padding[0] - module.dilation[0] * (module.kernel_size[0] - 1) - 1)
|
| 245 |
+
/ module.stride[0]
|
| 246 |
+
+ 1
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def is_valid_audio(audio):
|
| 251 |
+
return is_numpy_array(audio) or is_torch_tensor(audio)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def is_valid_list_of_audio(audio):
|
| 255 |
+
return audio and all(is_valid_audio(audio_i) for audio_i in audio)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def make_list_of_audio(
|
| 259 |
+
audio: list[AudioInput] | AudioInput,
|
| 260 |
+
) -> AudioInput:
|
| 261 |
+
"""
|
| 262 |
+
Ensure that the output is a list of audio.
|
| 263 |
+
Args:
|
| 264 |
+
audio (`Union[list[AudioInput], AudioInput]`):
|
| 265 |
+
The input audio.
|
| 266 |
+
Returns:
|
| 267 |
+
list: A list of audio.
|
| 268 |
+
"""
|
| 269 |
+
# If it's a list of audios, it's already in the right format
|
| 270 |
+
if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
|
| 271 |
+
return audio
|
| 272 |
+
|
| 273 |
+
# If it's a single audio, convert it to a list of
|
| 274 |
+
if is_valid_audio(audio):
|
| 275 |
+
return [audio]
|
| 276 |
+
|
| 277 |
+
raise ValueError("Invalid input type. Must be a single audio or a list of audio")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def hertz_to_mel(freq: float | np.ndarray, mel_scale: str = "htk") -> float | np.ndarray:
|
| 281 |
+
"""
|
| 282 |
+
Convert frequency from hertz to mels.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
freq (`float` or `np.ndarray`):
|
| 286 |
+
The frequency, or multiple frequencies, in hertz (Hz).
|
| 287 |
+
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
| 288 |
+
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
`float` or `np.ndarray`: The frequencies on the mel scale.
|
| 292 |
+
"""
|
| 293 |
+
|
| 294 |
+
if mel_scale not in ["slaney", "htk", "kaldi"]:
|
| 295 |
+
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
|
| 296 |
+
|
| 297 |
+
if mel_scale == "htk":
|
| 298 |
+
return 2595.0 * np.log10(1.0 + (freq / 700.0))
|
| 299 |
+
elif mel_scale == "kaldi":
|
| 300 |
+
return 1127.0 * np.log(1.0 + (freq / 700.0))
|
| 301 |
+
|
| 302 |
+
min_log_hertz = 1000.0
|
| 303 |
+
min_log_mel = 15.0
|
| 304 |
+
logstep = 27.0 / np.log(6.4)
|
| 305 |
+
mels = 3.0 * freq / 200.0
|
| 306 |
+
|
| 307 |
+
if isinstance(freq, np.ndarray):
|
| 308 |
+
log_region = freq >= min_log_hertz
|
| 309 |
+
mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
|
| 310 |
+
elif freq >= min_log_hertz:
|
| 311 |
+
mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
|
| 312 |
+
|
| 313 |
+
return mels
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def mel_to_hertz(mels: float | np.ndarray, mel_scale: str = "htk") -> float | np.ndarray:
|
| 317 |
+
"""
|
| 318 |
+
Convert frequency from mels to hertz.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
mels (`float` or `np.ndarray`):
|
| 322 |
+
The frequency, or multiple frequencies, in mels.
|
| 323 |
+
mel_scale (`str`, *optional*, `"htk"`):
|
| 324 |
+
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
`float` or `np.ndarray`: The frequencies in hertz.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
if mel_scale not in ["slaney", "htk", "kaldi"]:
|
| 331 |
+
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
|
| 332 |
+
|
| 333 |
+
if mel_scale == "htk":
|
| 334 |
+
return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
|
| 335 |
+
elif mel_scale == "kaldi":
|
| 336 |
+
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
|
| 337 |
+
|
| 338 |
+
min_log_hertz = 1000.0
|
| 339 |
+
min_log_mel = 15.0
|
| 340 |
+
logstep = np.log(6.4) / 27.0
|
| 341 |
+
freq = 200.0 * mels / 3.0
|
| 342 |
+
|
| 343 |
+
if isinstance(mels, np.ndarray):
|
| 344 |
+
log_region = mels >= min_log_mel
|
| 345 |
+
freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
|
| 346 |
+
elif mels >= min_log_mel:
|
| 347 |
+
freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
|
| 348 |
+
|
| 349 |
+
return freq
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def hertz_to_octave(freq: float | np.ndarray, tuning: float = 0.0, bins_per_octave: int = 12):
|
| 353 |
+
"""
|
| 354 |
+
Convert frequency from hertz to fractional octave numbers.
|
| 355 |
+
Adapted from *librosa*.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
freq (`float` or `np.ndarray`):
|
| 359 |
+
The frequency, or multiple frequencies, in hertz (Hz).
|
| 360 |
+
tuning (`float`, defaults to `0.`):
|
| 361 |
+
Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
|
| 362 |
+
bins_per_octave (`int`, defaults to `12`):
|
| 363 |
+
Number of bins per octave.
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
`float` or `np.ndarray`: The frequencies on the octave scale.
|
| 367 |
+
"""
|
| 368 |
+
stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
|
| 369 |
+
octave = np.log2(freq / (float(stuttgart_pitch) / 16))
|
| 370 |
+
return octave
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
|
| 374 |
+
"""
|
| 375 |
+
Creates a triangular filter bank.
|
| 376 |
+
|
| 377 |
+
Adapted from *torchaudio* and *librosa*.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
|
| 381 |
+
Discrete frequencies of the FFT bins in Hz.
|
| 382 |
+
filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
|
| 383 |
+
Center frequencies of the triangular filters to create, in Hz.
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
`np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
|
| 387 |
+
"""
|
| 388 |
+
filter_diff = np.diff(filter_freqs)
|
| 389 |
+
slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
|
| 390 |
+
down_slopes = -slopes[:, :-2] / filter_diff[:-1]
|
| 391 |
+
up_slopes = slopes[:, 2:] / filter_diff[1:]
|
| 392 |
+
return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def chroma_filter_bank(
|
| 396 |
+
num_frequency_bins: int,
|
| 397 |
+
num_chroma: int,
|
| 398 |
+
sampling_rate: int,
|
| 399 |
+
tuning: float = 0.0,
|
| 400 |
+
power: float | None = 2.0,
|
| 401 |
+
weighting_parameters: tuple[float, float] | None = (5.0, 2.0),
|
| 402 |
+
start_at_c_chroma: bool = True,
|
| 403 |
+
):
|
| 404 |
+
"""
|
| 405 |
+
Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.
|
| 406 |
+
|
| 407 |
+
Adapted from *librosa*.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
num_frequency_bins (`int`):
|
| 411 |
+
Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
|
| 412 |
+
num_chroma (`int`):
|
| 413 |
+
Number of chroma bins (i.e pitch classes).
|
| 414 |
+
sampling_rate (`float`):
|
| 415 |
+
Sample rate of the audio waveform.
|
| 416 |
+
tuning (`float`):
|
| 417 |
+
Tuning deviation from A440 in fractions of a chroma bin.
|
| 418 |
+
power (`float`, *optional*, defaults to 2.0):
|
| 419 |
+
If 12.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm.
|
| 420 |
+
weighting_parameters (`tuple[float, float]`, *optional*, defaults to `(5., 2.)`):
|
| 421 |
+
If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
|
| 422 |
+
the second element being the Gaussian half-width.
|
| 423 |
+
start_at_c_chroma (`bool`, *optional*, defaults to `True`):
|
| 424 |
+
If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
|
| 425 |
+
Returns:
|
| 426 |
+
`np.ndarray` of shape `(num_frequency_bins, num_chroma)`
|
| 427 |
+
"""
|
| 428 |
+
# Get the FFT bins, not counting the DC component
|
| 429 |
+
frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]
|
| 430 |
+
|
| 431 |
+
freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)
|
| 432 |
+
|
| 433 |
+
# make up a value for the 0 Hz bin = 1.5 octaves below bin 1
|
| 434 |
+
# (so chroma is 50% rotated from bin 1, and bin width is broad)
|
| 435 |
+
freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))
|
| 436 |
+
|
| 437 |
+
bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))
|
| 438 |
+
|
| 439 |
+
chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T
|
| 440 |
+
|
| 441 |
+
num_chroma2 = np.round(float(num_chroma) / 2)
|
| 442 |
+
|
| 443 |
+
# Project into range -num_chroma/2 .. num_chroma/2
|
| 444 |
+
# add on fixed offset of 10*num_chroma to ensure all values passed to
|
| 445 |
+
# rem are positive
|
| 446 |
+
chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2
|
| 447 |
+
|
| 448 |
+
# Gaussian bumps - 2*D to make them narrower
|
| 449 |
+
chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)
|
| 450 |
+
|
| 451 |
+
# normalize each column
|
| 452 |
+
if power is not None:
|
| 453 |
+
chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)
|
| 454 |
+
|
| 455 |
+
# Maybe apply scaling for fft bins
|
| 456 |
+
if weighting_parameters is not None:
|
| 457 |
+
center, half_width = weighting_parameters
|
| 458 |
+
chroma_filters *= np.tile(
|
| 459 |
+
np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
|
| 460 |
+
(num_chroma, 1),
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
if start_at_c_chroma:
|
| 464 |
+
chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)
|
| 465 |
+
|
| 466 |
+
# remove aliasing columns, copy to ensure row-contiguity
|
| 467 |
+
return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def mel_filter_bank(
|
| 471 |
+
num_frequency_bins: int,
|
| 472 |
+
num_mel_filters: int,
|
| 473 |
+
min_frequency: float,
|
| 474 |
+
max_frequency: float,
|
| 475 |
+
sampling_rate: int,
|
| 476 |
+
norm: str | None = None,
|
| 477 |
+
mel_scale: str = "htk",
|
| 478 |
+
triangularize_in_mel_space: bool = False,
|
| 479 |
+
) -> np.ndarray:
|
| 480 |
+
"""
|
| 481 |
+
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
|
| 482 |
+
various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
|
| 483 |
+
are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
|
| 484 |
+
features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
|
| 485 |
+
|
| 486 |
+
Different banks of mel filters were introduced in the literature. The following variations are supported:
|
| 487 |
+
|
| 488 |
+
- MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
|
| 489 |
+
bandwidth of `[0, 4600]` Hz.
|
| 490 |
+
- MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
|
| 491 |
+
bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
|
| 492 |
+
- MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
|
| 493 |
+
speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
|
| 494 |
+
- HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
|
| 495 |
+
12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
|
| 496 |
+
|
| 497 |
+
This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
|
| 498 |
+
`melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
num_frequency_bins (`int`):
|
| 502 |
+
Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier Transform used to compute the spectrogram).
|
| 503 |
+
num_mel_filters (`int`):
|
| 504 |
+
Number of mel filters to generate.
|
| 505 |
+
min_frequency (`float`):
|
| 506 |
+
Lowest frequency of interest in Hz.
|
| 507 |
+
max_frequency (`float`):
|
| 508 |
+
Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
|
| 509 |
+
sampling_rate (`int`):
|
| 510 |
+
Sample rate of the audio waveform.
|
| 511 |
+
norm (`str`, *optional*):
|
| 512 |
+
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
|
| 513 |
+
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
| 514 |
+
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
| 515 |
+
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
|
| 516 |
+
If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
|
| 517 |
+
should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
|
| 521 |
+
projection matrix to go from a spectrogram to a mel spectrogram.
|
| 522 |
+
"""
|
| 523 |
+
if norm is not None and norm != "slaney":
|
| 524 |
+
raise ValueError('norm must be one of None or "slaney"')
|
| 525 |
+
|
| 526 |
+
if num_frequency_bins < 2:
|
| 527 |
+
raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2")
|
| 528 |
+
|
| 529 |
+
if min_frequency > max_frequency:
|
| 530 |
+
raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}")
|
| 531 |
+
|
| 532 |
+
# center points of the triangular mel filters
|
| 533 |
+
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
|
| 534 |
+
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
|
| 535 |
+
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
|
| 536 |
+
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
|
| 537 |
+
|
| 538 |
+
if triangularize_in_mel_space:
|
| 539 |
+
# frequencies of FFT bins in Hz, but filters triangularized in mel space
|
| 540 |
+
fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2)
|
| 541 |
+
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
|
| 542 |
+
filter_freqs = mel_freqs
|
| 543 |
+
else:
|
| 544 |
+
# frequencies of FFT bins in Hz
|
| 545 |
+
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
|
| 546 |
+
|
| 547 |
+
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
|
| 548 |
+
|
| 549 |
+
if norm is not None and norm == "slaney":
|
| 550 |
+
# Slaney-style mel is scaled to be approx constant energy per channel
|
| 551 |
+
enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
|
| 552 |
+
mel_filters *= np.expand_dims(enorm, 0)
|
| 553 |
+
|
| 554 |
+
if (mel_filters.max(axis=0) == 0.0).any():
|
| 555 |
+
warnings.warn(
|
| 556 |
+
"At least one mel filter has all zero values. "
|
| 557 |
+
f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
|
| 558 |
+
f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
return mel_filters
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def optimal_fft_length(window_length: int) -> int:
|
| 565 |
+
"""
|
| 566 |
+
Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
|
| 567 |
+
already a power of two, rounds it up to the next power or two.
|
| 568 |
+
|
| 569 |
+
The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
|
| 570 |
+
of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
|
| 571 |
+
is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
|
| 572 |
+
it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
|
| 573 |
+
"""
|
| 574 |
+
return 2 ** int(np.ceil(np.log2(window_length)))
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def window_function(
|
| 578 |
+
window_length: int,
|
| 579 |
+
name: str = "hann",
|
| 580 |
+
periodic: bool = True,
|
| 581 |
+
frame_length: int | None = None,
|
| 582 |
+
center: bool = True,
|
| 583 |
+
) -> np.ndarray:
|
| 584 |
+
"""
|
| 585 |
+
Returns an array containing the specified window. This window is intended to be used with `stft`.
|
| 586 |
+
|
| 587 |
+
The following window types are supported:
|
| 588 |
+
|
| 589 |
+
- `"boxcar"`: a rectangular window
|
| 590 |
+
- `"hamming"`: the Hamming window
|
| 591 |
+
- `"hann"`: the Hann window
|
| 592 |
+
- `"povey"`: the Povey window
|
| 593 |
+
|
| 594 |
+
Args:
|
| 595 |
+
window_length (`int`):
|
| 596 |
+
The length of the window in samples.
|
| 597 |
+
name (`str`, *optional*, defaults to `"hann"`):
|
| 598 |
+
The name of the window function.
|
| 599 |
+
periodic (`bool`, *optional*, defaults to `True`):
|
| 600 |
+
Whether the window is periodic or symmetric.
|
| 601 |
+
frame_length (`int`, *optional*):
|
| 602 |
+
The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
|
| 603 |
+
than the frame length, so that it will be zero-padded.
|
| 604 |
+
center (`bool`, *optional*, defaults to `True`):
|
| 605 |
+
Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
`np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
|
| 609 |
+
"""
|
| 610 |
+
length = window_length + 1 if periodic else window_length
|
| 611 |
+
|
| 612 |
+
if name == "boxcar":
|
| 613 |
+
window = np.ones(length)
|
| 614 |
+
elif name in ["hamming", "hamming_window"]:
|
| 615 |
+
window = np.hamming(length)
|
| 616 |
+
elif name in ["hann", "hann_window"]:
|
| 617 |
+
window = np.hanning(length)
|
| 618 |
+
elif name == "povey":
|
| 619 |
+
window = np.power(np.hanning(length), 0.85)
|
| 620 |
+
else:
|
| 621 |
+
raise ValueError(f"Unknown window function '{name}'")
|
| 622 |
+
|
| 623 |
+
if periodic:
|
| 624 |
+
window = window[:-1]
|
| 625 |
+
|
| 626 |
+
if frame_length is None:
|
| 627 |
+
return window
|
| 628 |
+
|
| 629 |
+
if window_length > frame_length:
|
| 630 |
+
raise ValueError(
|
| 631 |
+
f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
padded_window = np.zeros(frame_length)
|
| 635 |
+
offset = (frame_length - window_length) // 2 if center else 0
|
| 636 |
+
padded_window[offset : offset + window_length] = window
|
| 637 |
+
return padded_window
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# Note: This method processes a single waveform. For batch processing, use spectrogram_batch().
|
| 641 |
+
def spectrogram(
|
| 642 |
+
waveform: np.ndarray,
|
| 643 |
+
window: np.ndarray,
|
| 644 |
+
frame_length: int,
|
| 645 |
+
hop_length: int,
|
| 646 |
+
fft_length: int | None = None,
|
| 647 |
+
power: float | None = 1.0,
|
| 648 |
+
center: bool = True,
|
| 649 |
+
pad_mode: str = "reflect",
|
| 650 |
+
onesided: bool = True,
|
| 651 |
+
dither: float = 0.0,
|
| 652 |
+
preemphasis: float | None = None,
|
| 653 |
+
mel_filters: np.ndarray | None = None,
|
| 654 |
+
mel_floor: float = 1e-10,
|
| 655 |
+
log_mel: str | None = None,
|
| 656 |
+
reference: float = 1.0,
|
| 657 |
+
min_value: float = 1e-10,
|
| 658 |
+
db_range: float | None = None,
|
| 659 |
+
remove_dc_offset: bool = False,
|
| 660 |
+
dtype: np.dtype = np.float32,
|
| 661 |
+
) -> np.ndarray:
|
| 662 |
+
"""
|
| 663 |
+
Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
|
| 664 |
+
|
| 665 |
+
This function can create the following kinds of spectrograms:
|
| 666 |
+
|
| 667 |
+
- amplitude spectrogram (`power = 1.0`)
|
| 668 |
+
- power spectrogram (`power = 2.0`)
|
| 669 |
+
- complex-valued spectrogram (`power = None`)
|
| 670 |
+
- log spectrogram (use `log_mel` argument)
|
| 671 |
+
- mel spectrogram (provide `mel_filters`)
|
| 672 |
+
- log-mel spectrogram (provide `mel_filters` and `log_mel`)
|
| 673 |
+
|
| 674 |
+
How this works:
|
| 675 |
+
|
| 676 |
+
1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
|
| 677 |
+
- hop_length` samples.
|
| 678 |
+
2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
|
| 679 |
+
3. The DFT is taken of each windowed frame.
|
| 680 |
+
4. The results are stacked into a spectrogram.
|
| 681 |
+
|
| 682 |
+
We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
|
| 683 |
+
|
| 684 |
+
- The analysis frame. This is the size of the time slices that the input waveform is split into.
|
| 685 |
+
- The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
|
| 686 |
+
- The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
|
| 687 |
+
|
| 688 |
+
In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
|
| 689 |
+
padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
|
| 690 |
+
typically the next power of two.
|
| 691 |
+
|
| 692 |
+
Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
|
| 693 |
+
`torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
|
| 694 |
+
can be constructed.
|
| 695 |
+
|
| 696 |
+
Args:
|
| 697 |
+
waveform (`np.ndarray` of shape `(length,)`):
|
| 698 |
+
The input waveform. This must be a single real-valued, mono waveform.
|
| 699 |
+
window (`np.ndarray` of shape `(frame_length,)`):
|
| 700 |
+
The windowing function to apply, including zero-padding if necessary. The actual window length may be
|
| 701 |
+
shorter than `frame_length`, but we're assuming the array has already been zero-padded.
|
| 702 |
+
frame_length (`int`):
|
| 703 |
+
The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
|
| 704 |
+
allow smaller sizes.
|
| 705 |
+
hop_length (`int`):
|
| 706 |
+
The stride between successive analysis frames in samples.
|
| 707 |
+
fft_length (`int`, *optional*):
|
| 708 |
+
The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
|
| 709 |
+
For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
|
| 710 |
+
power (`float`, *optional*, defaults to 1.0):
|
| 711 |
+
If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
|
| 712 |
+
complex numbers.
|
| 713 |
+
center (`bool`, *optional*, defaults to `True`):
|
| 714 |
+
Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
|
| 715 |
+
`t` will start at time `t * hop_length`.
|
| 716 |
+
pad_mode (`str`, *optional*, defaults to `"reflect"`):
|
| 717 |
+
Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
|
| 718 |
+
(pad with edge values), `"reflect"` (pads with mirrored values).
|
| 719 |
+
onesided (`bool`, *optional*, defaults to `True`):
|
| 720 |
+
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
|
| 721 |
+
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
|
| 722 |
+
dither (`float`, *optional*, defaults to 0.0):
|
| 723 |
+
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
| 724 |
+
E.g. use 4.0 to add dithering with a normal distribution centered
|
| 725 |
+
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
|
| 726 |
+
Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
|
| 727 |
+
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
|
| 728 |
+
preemphasis (`float`, *optional*)
|
| 729 |
+
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
|
| 730 |
+
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
|
| 731 |
+
The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
|
| 732 |
+
mel_floor (`float`, *optional*, defaults to 1e-10):
|
| 733 |
+
Minimum value of mel frequency banks.
|
| 734 |
+
log_mel (`str`, *optional*):
|
| 735 |
+
How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
|
| 736 |
+
the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
|
| 737 |
+
used when `power` is not `None`.
|
| 738 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 739 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
| 740 |
+
the loudest part to 0 dB. Must be greater than zero.
|
| 741 |
+
min_value (`float`, *optional*, defaults to `1e-10`):
|
| 742 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
| 743 |
+
`log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
|
| 744 |
+
amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
|
| 745 |
+
db_range (`float`, *optional*):
|
| 746 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
| 747 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
| 748 |
+
remove_dc_offset (`bool`, *optional*):
|
| 749 |
+
Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
|
| 750 |
+
order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
|
| 751 |
+
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
| 752 |
+
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
|
| 753 |
+
`np.complex64`.
|
| 754 |
+
|
| 755 |
+
Returns:
|
| 756 |
+
`nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
|
| 757 |
+
`(num_mel_filters, length)` for a mel spectrogram.
|
| 758 |
+
"""
|
| 759 |
+
window_length = len(window)
|
| 760 |
+
|
| 761 |
+
if fft_length is None:
|
| 762 |
+
fft_length = frame_length
|
| 763 |
+
|
| 764 |
+
if frame_length > fft_length:
|
| 765 |
+
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
|
| 766 |
+
|
| 767 |
+
if window_length != frame_length:
|
| 768 |
+
raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
|
| 769 |
+
|
| 770 |
+
if hop_length <= 0:
|
| 771 |
+
raise ValueError("hop_length must be greater than zero")
|
| 772 |
+
|
| 773 |
+
if waveform.ndim != 1:
|
| 774 |
+
raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
|
| 775 |
+
|
| 776 |
+
if np.iscomplexobj(waveform):
|
| 777 |
+
raise ValueError("Complex-valued input waveforms are not currently supported")
|
| 778 |
+
|
| 779 |
+
if power is None and mel_filters is not None:
|
| 780 |
+
raise ValueError(
|
| 781 |
+
"You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
|
| 782 |
+
"Specify `power` to fix this issue."
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
# center pad the waveform
|
| 786 |
+
if center:
|
| 787 |
+
padding = [(int(frame_length // 2), int(frame_length // 2))]
|
| 788 |
+
waveform = np.pad(waveform, padding, mode=pad_mode)
|
| 789 |
+
|
| 790 |
+
# promote to float64, since np.fft uses float64 internally
|
| 791 |
+
waveform = waveform.astype(np.float64)
|
| 792 |
+
window = window.astype(np.float64)
|
| 793 |
+
|
| 794 |
+
# split waveform into frames of frame_length size
|
| 795 |
+
num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
|
| 796 |
+
|
| 797 |
+
num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
|
| 798 |
+
spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
|
| 799 |
+
|
| 800 |
+
# rfft is faster than fft
|
| 801 |
+
fft_func = np.fft.rfft if onesided else np.fft.fft
|
| 802 |
+
buffer = np.zeros(fft_length)
|
| 803 |
+
|
| 804 |
+
timestep = 0
|
| 805 |
+
for frame_idx in range(num_frames):
|
| 806 |
+
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
|
| 807 |
+
|
| 808 |
+
if dither != 0.0:
|
| 809 |
+
buffer[:frame_length] += dither * np.random.randn(frame_length)
|
| 810 |
+
|
| 811 |
+
if remove_dc_offset:
|
| 812 |
+
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
|
| 813 |
+
|
| 814 |
+
if preemphasis is not None:
|
| 815 |
+
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
|
| 816 |
+
buffer[0] *= 1 - preemphasis
|
| 817 |
+
|
| 818 |
+
buffer[:frame_length] *= window
|
| 819 |
+
|
| 820 |
+
spectrogram[frame_idx] = fft_func(buffer)
|
| 821 |
+
timestep += hop_length
|
| 822 |
+
|
| 823 |
+
# note: ** is much faster than np.power
|
| 824 |
+
if power is not None:
|
| 825 |
+
spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
|
| 826 |
+
|
| 827 |
+
spectrogram = spectrogram.T
|
| 828 |
+
|
| 829 |
+
if mel_filters is not None:
|
| 830 |
+
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
|
| 831 |
+
|
| 832 |
+
if power is not None and log_mel is not None:
|
| 833 |
+
if log_mel == "log":
|
| 834 |
+
spectrogram = np.log(spectrogram)
|
| 835 |
+
elif log_mel == "log10":
|
| 836 |
+
spectrogram = np.log10(spectrogram)
|
| 837 |
+
elif log_mel == "dB":
|
| 838 |
+
if power == 1.0:
|
| 839 |
+
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
|
| 840 |
+
elif power == 2.0:
|
| 841 |
+
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
|
| 842 |
+
else:
|
| 843 |
+
raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
|
| 844 |
+
else:
|
| 845 |
+
raise ValueError(f"Unknown log_mel option: {log_mel}")
|
| 846 |
+
|
| 847 |
+
spectrogram = np.asarray(spectrogram, dtype)
|
| 848 |
+
|
| 849 |
+
return spectrogram
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
def spectrogram_batch(
|
| 853 |
+
waveform_list: list[np.ndarray],
|
| 854 |
+
window: np.ndarray,
|
| 855 |
+
frame_length: int,
|
| 856 |
+
hop_length: int,
|
| 857 |
+
fft_length: int | None = None,
|
| 858 |
+
power: float | None = 1.0,
|
| 859 |
+
center: bool = True,
|
| 860 |
+
pad_mode: str = "reflect",
|
| 861 |
+
onesided: bool = True,
|
| 862 |
+
dither: float = 0.0,
|
| 863 |
+
preemphasis: float | None = None,
|
| 864 |
+
mel_filters: np.ndarray | None = None,
|
| 865 |
+
mel_floor: float = 1e-10,
|
| 866 |
+
log_mel: str | None = None,
|
| 867 |
+
reference: float = 1.0,
|
| 868 |
+
min_value: float = 1e-10,
|
| 869 |
+
db_range: float | None = None,
|
| 870 |
+
remove_dc_offset: bool = False,
|
| 871 |
+
dtype: np.dtype = np.float32,
|
| 872 |
+
) -> list[np.ndarray]:
|
| 873 |
+
"""
|
| 874 |
+
Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
|
| 875 |
+
This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.
|
| 876 |
+
|
| 877 |
+
It supports generating various types of spectrograms:
|
| 878 |
+
|
| 879 |
+
- amplitude spectrogram (`power = 1.0`)
|
| 880 |
+
- power spectrogram (`power = 2.0`)
|
| 881 |
+
- complex-valued spectrogram (`power = None`)
|
| 882 |
+
- log spectrogram (use `log_mel` argument)
|
| 883 |
+
- mel spectrogram (provide `mel_filters`)
|
| 884 |
+
- log-mel spectrogram (provide `mel_filters` and `log_mel`)
|
| 885 |
+
|
| 886 |
+
How this works:
|
| 887 |
+
|
| 888 |
+
1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
|
| 889 |
+
- hop_length` samples.
|
| 890 |
+
2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
|
| 891 |
+
3. The DFT is taken of each windowed frame.
|
| 892 |
+
4. The results are stacked into a spectrogram.
|
| 893 |
+
|
| 894 |
+
We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
|
| 895 |
+
|
| 896 |
+
- The analysis frame. This is the size of the time slices that the input waveform is split into.
|
| 897 |
+
- The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
|
| 898 |
+
- The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
|
| 899 |
+
|
| 900 |
+
In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
|
| 901 |
+
padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
|
| 902 |
+
typically the next power of two.
|
| 903 |
+
|
| 904 |
+
Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.
|
| 905 |
+
|
| 906 |
+
Args:
|
| 907 |
+
waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`):
|
| 908 |
+
The list of input waveforms, each a single-channel (mono) signal.
|
| 909 |
+
window (`np.ndarray` of shape `(frame_length,)`):
|
| 910 |
+
The windowing function to apply, including zero-padding if necessary.
|
| 911 |
+
frame_length (`int`):
|
| 912 |
+
The length of each frame for analysis.
|
| 913 |
+
hop_length (`int`):
|
| 914 |
+
The step size between successive frames.
|
| 915 |
+
fft_length (`int`, *optional*):
|
| 916 |
+
The size of the FFT buffer, defining frequency bin resolution.
|
| 917 |
+
power (`float`, *optional*, defaults to 1.0):
|
| 918 |
+
Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
|
| 919 |
+
center (`bool`, *optional*, defaults to `True`):
|
| 920 |
+
Whether to center-pad the waveform frames.
|
| 921 |
+
pad_mode (`str`, *optional*, defaults to `"reflect"`):
|
| 922 |
+
The padding strategy when `center` is `True`.
|
| 923 |
+
onesided (`bool`, *optional*, defaults to `True`):
|
| 924 |
+
If True, returns a one-sided spectrogram for real input signals.
|
| 925 |
+
dither (`float`, *optional*, defaults to 0.0):
|
| 926 |
+
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
| 927 |
+
E.g. use 4.0 to add dithering with a normal distribution centered
|
| 928 |
+
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
|
| 929 |
+
preemphasis (`float`, *optional*):
|
| 930 |
+
Applies a pre-emphasis filter to each frame.
|
| 931 |
+
mel_filters (`np.ndarray`, *optional*):
|
| 932 |
+
Mel filter bank for converting to mel spectrogram.
|
| 933 |
+
mel_floor (`float`, *optional*, defaults to 1e-10):
|
| 934 |
+
Floor value for mel spectrogram to avoid log(0).
|
| 935 |
+
log_mel (`str`, *optional*):
|
| 936 |
+
Specifies log scaling strategy; options are None, "log", "log10", "dB".
|
| 937 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 938 |
+
Reference value for dB conversion in log_mel.
|
| 939 |
+
min_value (`float`, *optional*, defaults to 1e-10):
|
| 940 |
+
Minimum floor value for log scale conversions.
|
| 941 |
+
db_range (`float`, *optional*):
|
| 942 |
+
Dynamic range for dB scale spectrograms.
|
| 943 |
+
remove_dc_offset (`bool`, *optional*):
|
| 944 |
+
Whether to remove the DC offset from each frame.
|
| 945 |
+
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
| 946 |
+
Data type of the output spectrogram.
|
| 947 |
+
|
| 948 |
+
Returns:
|
| 949 |
+
list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
|
| 950 |
+
"""
|
| 951 |
+
window_length = len(window)
|
| 952 |
+
|
| 953 |
+
if fft_length is None:
|
| 954 |
+
fft_length = frame_length
|
| 955 |
+
|
| 956 |
+
if frame_length > fft_length:
|
| 957 |
+
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
|
| 958 |
+
|
| 959 |
+
if window_length != frame_length:
|
| 960 |
+
raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
|
| 961 |
+
|
| 962 |
+
if hop_length <= 0:
|
| 963 |
+
raise ValueError("hop_length must be greater than zero")
|
| 964 |
+
|
| 965 |
+
# Check the dimensions of the waveform , and if waveform is complex
|
| 966 |
+
for waveform in waveform_list:
|
| 967 |
+
if waveform.ndim != 1:
|
| 968 |
+
raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
|
| 969 |
+
if np.iscomplexobj(waveform):
|
| 970 |
+
raise ValueError("Complex-valued input waveforms are not currently supported")
|
| 971 |
+
# Center pad the waveform
|
| 972 |
+
if center:
|
| 973 |
+
padding = [(int(frame_length // 2), int(frame_length // 2))]
|
| 974 |
+
waveform_list = [
|
| 975 |
+
np.pad(
|
| 976 |
+
waveform,
|
| 977 |
+
padding,
|
| 978 |
+
mode=pad_mode,
|
| 979 |
+
)
|
| 980 |
+
for waveform in waveform_list
|
| 981 |
+
]
|
| 982 |
+
original_waveform_lengths = [
|
| 983 |
+
len(waveform) for waveform in waveform_list
|
| 984 |
+
] # these lengths will be used to remove padding later
|
| 985 |
+
|
| 986 |
+
# Batch pad the waveform
|
| 987 |
+
max_length = max(original_waveform_lengths)
|
| 988 |
+
padded_waveform_batch = np.array(
|
| 989 |
+
[
|
| 990 |
+
np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
|
| 991 |
+
for waveform in waveform_list
|
| 992 |
+
],
|
| 993 |
+
dtype=dtype,
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
# Promote to float64, since np.fft uses float64 internally
|
| 997 |
+
padded_waveform_batch = padded_waveform_batch.astype(np.float64)
|
| 998 |
+
window = window.astype(np.float64)
|
| 999 |
+
|
| 1000 |
+
# Split waveform into frames of frame_length size
|
| 1001 |
+
num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
|
| 1002 |
+
# these lengths will be used to remove padding later
|
| 1003 |
+
true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
|
| 1004 |
+
num_batches = padded_waveform_batch.shape[0]
|
| 1005 |
+
|
| 1006 |
+
num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
|
| 1007 |
+
spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)
|
| 1008 |
+
|
| 1009 |
+
# rfft is faster than fft
|
| 1010 |
+
fft_func = np.fft.rfft if onesided else np.fft.fft
|
| 1011 |
+
buffer = np.zeros((num_batches, fft_length))
|
| 1012 |
+
|
| 1013 |
+
for frame_idx in range(num_frames):
|
| 1014 |
+
timestep = frame_idx * hop_length
|
| 1015 |
+
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
|
| 1016 |
+
|
| 1017 |
+
if dither != 0.0:
|
| 1018 |
+
buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
|
| 1019 |
+
|
| 1020 |
+
if remove_dc_offset:
|
| 1021 |
+
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
|
| 1022 |
+
|
| 1023 |
+
if preemphasis is not None:
|
| 1024 |
+
buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
|
| 1025 |
+
buffer[:, 0] *= 1 - preemphasis
|
| 1026 |
+
|
| 1027 |
+
buffer[:, :frame_length] *= window
|
| 1028 |
+
|
| 1029 |
+
spectrogram[:, frame_idx] = fft_func(buffer)
|
| 1030 |
+
|
| 1031 |
+
# Note: ** is much faster than np.power
|
| 1032 |
+
if power is not None:
|
| 1033 |
+
spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
|
| 1034 |
+
|
| 1035 |
+
# Apply mel filters if provided
|
| 1036 |
+
if mel_filters is not None:
|
| 1037 |
+
result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
|
| 1038 |
+
spectrogram = np.maximum(mel_floor, result)
|
| 1039 |
+
|
| 1040 |
+
# Convert to log scale if specified
|
| 1041 |
+
if power is not None and log_mel is not None:
|
| 1042 |
+
if log_mel == "log":
|
| 1043 |
+
spectrogram = np.log(spectrogram)
|
| 1044 |
+
elif log_mel == "log10":
|
| 1045 |
+
spectrogram = np.log10(spectrogram)
|
| 1046 |
+
elif log_mel == "dB":
|
| 1047 |
+
if power == 1.0:
|
| 1048 |
+
spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
|
| 1049 |
+
elif power == 2.0:
|
| 1050 |
+
spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
|
| 1051 |
+
else:
|
| 1052 |
+
raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
|
| 1053 |
+
else:
|
| 1054 |
+
raise ValueError(f"Unknown log_mel option: {log_mel}")
|
| 1055 |
+
|
| 1056 |
+
spectrogram = np.asarray(spectrogram, dtype)
|
| 1057 |
+
|
| 1058 |
+
spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
|
| 1059 |
+
|
| 1060 |
+
return spectrogram_list
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
def power_to_db(
|
| 1064 |
+
spectrogram: np.ndarray,
|
| 1065 |
+
reference: float = 1.0,
|
| 1066 |
+
min_value: float = 1e-10,
|
| 1067 |
+
db_range: float | None = None,
|
| 1068 |
+
) -> np.ndarray:
|
| 1069 |
+
"""
|
| 1070 |
+
Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
|
| 1071 |
+
logarithm properties for numerical stability.
|
| 1072 |
+
|
| 1073 |
+
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
| 1074 |
+
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
| 1075 |
+
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
| 1076 |
+
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
| 1077 |
+
|
| 1078 |
+
Based on the implementation of `librosa.power_to_db`.
|
| 1079 |
+
|
| 1080 |
+
Args:
|
| 1081 |
+
spectrogram (`np.ndarray`):
|
| 1082 |
+
The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
|
| 1083 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 1084 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
| 1085 |
+
the loudest part to 0 dB. Must be greater than zero.
|
| 1086 |
+
min_value (`float`, *optional*, defaults to `1e-10`):
|
| 1087 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
| 1088 |
+
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
|
| 1089 |
+
db_range (`float`, *optional*):
|
| 1090 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
| 1091 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
| 1092 |
+
|
| 1093 |
+
Returns:
|
| 1094 |
+
`np.ndarray`: the spectrogram in decibels
|
| 1095 |
+
"""
|
| 1096 |
+
if reference <= 0.0:
|
| 1097 |
+
raise ValueError("reference must be greater than zero")
|
| 1098 |
+
if min_value <= 0.0:
|
| 1099 |
+
raise ValueError("min_value must be greater than zero")
|
| 1100 |
+
|
| 1101 |
+
reference = max(min_value, reference)
|
| 1102 |
+
|
| 1103 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
| 1104 |
+
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
|
| 1105 |
+
|
| 1106 |
+
if db_range is not None:
|
| 1107 |
+
if db_range <= 0.0:
|
| 1108 |
+
raise ValueError("db_range must be greater than zero")
|
| 1109 |
+
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
| 1110 |
+
|
| 1111 |
+
return spectrogram
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
def power_to_db_batch(
|
| 1115 |
+
spectrogram: np.ndarray,
|
| 1116 |
+
reference: float = 1.0,
|
| 1117 |
+
min_value: float = 1e-10,
|
| 1118 |
+
db_range: float | None = None,
|
| 1119 |
+
) -> np.ndarray:
|
| 1120 |
+
"""
|
| 1121 |
+
Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
|
| 1122 |
+
using basic logarithm properties for numerical stability.
|
| 1123 |
+
|
| 1124 |
+
This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram.
|
| 1125 |
+
|
| 1126 |
+
Args:
|
| 1127 |
+
spectrogram (`np.ndarray`):
|
| 1128 |
+
The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
|
| 1129 |
+
Note that a power spectrogram has the amplitudes squared!
|
| 1130 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 1131 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
| 1132 |
+
the loudest part to 0 dB. Must be greater than zero.
|
| 1133 |
+
min_value (`float`, *optional*, defaults to `1e-10`):
|
| 1134 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
| 1135 |
+
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
|
| 1136 |
+
db_range (`float`, *optional*):
|
| 1137 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
| 1138 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
| 1139 |
+
|
| 1140 |
+
Returns:
|
| 1141 |
+
`np.ndarray`: the batch of spectrograms in decibels
|
| 1142 |
+
"""
|
| 1143 |
+
if reference <= 0.0:
|
| 1144 |
+
raise ValueError("reference must be greater than zero")
|
| 1145 |
+
if min_value <= 0.0:
|
| 1146 |
+
raise ValueError("min_value must be greater than zero")
|
| 1147 |
+
|
| 1148 |
+
reference = max(min_value, reference)
|
| 1149 |
+
|
| 1150 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
| 1151 |
+
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
|
| 1152 |
+
|
| 1153 |
+
if db_range is not None:
|
| 1154 |
+
if db_range <= 0.0:
|
| 1155 |
+
raise ValueError("db_range must be greater than zero")
|
| 1156 |
+
# Apply db_range clipping per batch item
|
| 1157 |
+
max_values = spectrogram.max(axis=(1, 2), keepdims=True)
|
| 1158 |
+
spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
|
| 1159 |
+
|
| 1160 |
+
return spectrogram
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
def amplitude_to_db(
|
| 1164 |
+
spectrogram: np.ndarray,
|
| 1165 |
+
reference: float = 1.0,
|
| 1166 |
+
min_value: float = 1e-5,
|
| 1167 |
+
db_range: float | None = None,
|
| 1168 |
+
) -> np.ndarray:
|
| 1169 |
+
"""
|
| 1170 |
+
Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
|
| 1171 |
+
basic logarithm properties for numerical stability.
|
| 1172 |
+
|
| 1173 |
+
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
| 1174 |
+
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
| 1175 |
+
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
| 1176 |
+
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
| 1177 |
+
|
| 1178 |
+
Args:
|
| 1179 |
+
spectrogram (`np.ndarray`):
|
| 1180 |
+
The input amplitude (mel) spectrogram.
|
| 1181 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 1182 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
| 1183 |
+
the loudest part to 0 dB. Must be greater than zero.
|
| 1184 |
+
min_value (`float`, *optional*, defaults to `1e-5`):
|
| 1185 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
| 1186 |
+
`log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
|
| 1187 |
+
db_range (`float`, *optional*):
|
| 1188 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
| 1189 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
| 1190 |
+
|
| 1191 |
+
Returns:
|
| 1192 |
+
`np.ndarray`: the spectrogram in decibels
|
| 1193 |
+
"""
|
| 1194 |
+
if reference <= 0.0:
|
| 1195 |
+
raise ValueError("reference must be greater than zero")
|
| 1196 |
+
if min_value <= 0.0:
|
| 1197 |
+
raise ValueError("min_value must be greater than zero")
|
| 1198 |
+
|
| 1199 |
+
reference = max(min_value, reference)
|
| 1200 |
+
|
| 1201 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
| 1202 |
+
spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
|
| 1203 |
+
|
| 1204 |
+
if db_range is not None:
|
| 1205 |
+
if db_range <= 0.0:
|
| 1206 |
+
raise ValueError("db_range must be greater than zero")
|
| 1207 |
+
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
| 1208 |
+
|
| 1209 |
+
return spectrogram
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
def amplitude_to_db_batch(
|
| 1213 |
+
spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: float | None = None
|
| 1214 |
+
) -> np.ndarray:
|
| 1215 |
+
"""
|
| 1216 |
+
Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
|
| 1217 |
+
using basic logarithm properties for numerical stability.
|
| 1218 |
+
|
| 1219 |
+
The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
|
| 1220 |
+
|
| 1221 |
+
Args:
|
| 1222 |
+
spectrogram (`np.ndarray`):
|
| 1223 |
+
The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
|
| 1224 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 1225 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
| 1226 |
+
the loudest part to 0 dB. Must be greater than zero.
|
| 1227 |
+
min_value (`float`, *optional*, defaults to `1e-5`):
|
| 1228 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
| 1229 |
+
`log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
|
| 1230 |
+
db_range (`float`, *optional*):
|
| 1231 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
| 1232 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
| 1233 |
+
|
| 1234 |
+
Returns:
|
| 1235 |
+
`np.ndarray`: the batch of spectrograms in decibels
|
| 1236 |
+
"""
|
| 1237 |
+
if reference <= 0.0:
|
| 1238 |
+
raise ValueError("reference must be greater than zero")
|
| 1239 |
+
if min_value <= 0.0:
|
| 1240 |
+
raise ValueError("min_value must be greater than zero")
|
| 1241 |
+
|
| 1242 |
+
reference = max(min_value, reference)
|
| 1243 |
+
|
| 1244 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
| 1245 |
+
spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
|
| 1246 |
+
|
| 1247 |
+
if db_range is not None:
|
| 1248 |
+
if db_range <= 0.0:
|
| 1249 |
+
raise ValueError("db_range must be greater than zero")
|
| 1250 |
+
# Apply db_range clipping per batch item
|
| 1251 |
+
max_values = spectrogram.max(axis=(1, 2), keepdims=True)
|
| 1252 |
+
spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
|
| 1253 |
+
|
| 1254 |
+
return spectrogram
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/add_new_model_like.py
ADDED
|
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import difflib
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
import subprocess
|
| 18 |
+
import textwrap
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
+
from datetime import date
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Annotated, Any, cast
|
| 23 |
+
|
| 24 |
+
import typer
|
| 25 |
+
|
| 26 |
+
from ..utils import is_libcst_available
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# We protect this import to avoid requiring it for all `transformers` CLI commands - however it is actually
|
| 30 |
+
# strictly required for this one (we need it both for modular and for the following Visitor)
|
| 31 |
+
if is_libcst_available():
|
| 32 |
+
import libcst as cst
|
| 33 |
+
from libcst import CSTVisitor
|
| 34 |
+
from libcst import matchers as m
|
| 35 |
+
|
| 36 |
+
class ClassFinder(CSTVisitor):
|
| 37 |
+
"""
|
| 38 |
+
A visitor to find all classes in a python module.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self):
|
| 42 |
+
self.classes: list = []
|
| 43 |
+
self.public_classes: list = []
|
| 44 |
+
self.is_in_class = False
|
| 45 |
+
|
| 46 |
+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
| 47 |
+
"""Record class names. We assume classes always only appear at top-level (i.e. no class definition in function or similar)"""
|
| 48 |
+
self.classes.append(node.name.value)
|
| 49 |
+
self.is_in_class = True
|
| 50 |
+
|
| 51 |
+
def leave_ClassDef(self, node: cst.ClassDef):
|
| 52 |
+
self.is_in_class = False
|
| 53 |
+
|
| 54 |
+
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):
|
| 55 |
+
"""Record all public classes inside the `__all__` assignment."""
|
| 56 |
+
simple_top_level_assign_structure = m.SimpleStatementLine(
|
| 57 |
+
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
|
| 58 |
+
)
|
| 59 |
+
if not self.is_in_class and m.matches(node, simple_top_level_assign_structure):
|
| 60 |
+
stmt = cast(cst.Assign, node.body[0])
|
| 61 |
+
assigned_variable = cast(cst.Name, stmt.targets[0].target).value
|
| 62 |
+
if assigned_variable == "__all__":
|
| 63 |
+
elements = cast(cst.Tuple, stmt.value).elements
|
| 64 |
+
self.public_classes = [cast(cst.SimpleString, element.value).value for element in elements]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
CURRENT_YEAR = date.today().year
|
| 68 |
+
REPO_PATH = Path(__file__).parents[3]
|
| 69 |
+
|
| 70 |
+
COPYRIGHT = f"""
|
| 71 |
+
# coding=utf-8
|
| 72 |
+
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
| 73 |
+
#
|
| 74 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 75 |
+
# you may not use this file except in compliance with the License.
|
| 76 |
+
# You may obtain a copy of the License at
|
| 77 |
+
#
|
| 78 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 79 |
+
#
|
| 80 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 81 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 82 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 83 |
+
# See the License for the specific language governing permissions and
|
| 84 |
+
# limitations under the License.
|
| 85 |
+
""".lstrip()
|
| 86 |
+
|
| 87 |
+
### Entrypoint
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def add_new_model_like(
|
| 91 |
+
repo_path: Annotated[
|
| 92 |
+
str | None, typer.Argument(help="When not using an editable install, the path to the Transformers repo.")
|
| 93 |
+
] = None,
|
| 94 |
+
):
|
| 95 |
+
"""
|
| 96 |
+
Add a new model to the library, based on an existing one.
|
| 97 |
+
"""
|
| 98 |
+
(
|
| 99 |
+
old_model_infos,
|
| 100 |
+
new_lowercase_name,
|
| 101 |
+
new_model_paper_name,
|
| 102 |
+
filenames_to_add,
|
| 103 |
+
) = get_user_input()
|
| 104 |
+
|
| 105 |
+
_add_new_model_like_internal(
|
| 106 |
+
repo_path=Path(repo_path) if repo_path is not None else REPO_PATH,
|
| 107 |
+
old_model_infos=old_model_infos,
|
| 108 |
+
new_lowercase_name=new_lowercase_name,
|
| 109 |
+
new_model_paper_name=new_model_paper_name,
|
| 110 |
+
filenames_to_add=filenames_to_add,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
### Core logic
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ModelInfos:
|
| 118 |
+
"""
|
| 119 |
+
Retrieve the basic information about an existing model classes.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def __init__(self, lowercase_name: str):
|
| 123 |
+
from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES
|
| 124 |
+
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
|
| 125 |
+
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
|
| 126 |
+
from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
|
| 127 |
+
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
|
| 128 |
+
from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES
|
| 129 |
+
|
| 130 |
+
# Just to make sure it's indeed lowercase
|
| 131 |
+
self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_")
|
| 132 |
+
if self.lowercase_name not in CONFIG_MAPPING_NAMES:
|
| 133 |
+
self.lowercase_name.replace("_", "-")
|
| 134 |
+
if self.lowercase_name not in CONFIG_MAPPING_NAMES:
|
| 135 |
+
raise ValueError(f"{lowercase_name} is not a valid model name")
|
| 136 |
+
|
| 137 |
+
self.config_class = CONFIG_MAPPING_NAMES[self.lowercase_name]
|
| 138 |
+
self.camelcase_name = self.config_class.replace("Config", "")
|
| 139 |
+
|
| 140 |
+
# Get tokenizer class
|
| 141 |
+
if self.lowercase_name in TOKENIZER_MAPPING_NAMES:
|
| 142 |
+
self.tokenizer_class = None
|
| 143 |
+
self.fast_tokenizer_class = TOKENIZER_MAPPING_NAMES[self.lowercase_name]
|
| 144 |
+
self.fast_tokenizer_class = (
|
| 145 |
+
None if self.fast_tokenizer_class == "PreTrainedTokenizerFast" else self.fast_tokenizer_class
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
self.tokenizer_class, self.fast_tokenizer_class = None, None
|
| 149 |
+
|
| 150 |
+
self.image_processor_classes = IMAGE_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
|
| 151 |
+
self.video_processor_class = VIDEO_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
|
| 152 |
+
self.feature_extractor_class = FEATURE_EXTRACTOR_MAPPING_NAMES.get(self.lowercase_name, None)
|
| 153 |
+
self.processor_class = PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def add_content_to_file(file_name: str | os.PathLike, new_content: str, add_after: str):
|
| 157 |
+
"""
|
| 158 |
+
A utility to add some content inside a given file.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
file_name (`str` or `os.PathLike`):
|
| 162 |
+
The name of the file in which we want to insert some content.
|
| 163 |
+
new_content (`str`):
|
| 164 |
+
The content to add.
|
| 165 |
+
add_after (`str`):
|
| 166 |
+
The new content is added just after the first instance matching it.
|
| 167 |
+
"""
|
| 168 |
+
with open(file_name, "r", encoding="utf-8") as f:
|
| 169 |
+
old_content = f.read()
|
| 170 |
+
|
| 171 |
+
before, after = old_content.split(add_after, 1)
|
| 172 |
+
new_content = before + add_after + new_content + after
|
| 173 |
+
|
| 174 |
+
with open(file_name, "w", encoding="utf-8") as f:
|
| 175 |
+
f.write(new_content)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def add_model_to_auto_mappings(
|
| 179 |
+
repo_path: Path,
|
| 180 |
+
old_model_infos: ModelInfos,
|
| 181 |
+
new_lowercase_name: str,
|
| 182 |
+
new_model_paper_name: str,
|
| 183 |
+
filenames_to_add: list[tuple[str, bool]],
|
| 184 |
+
):
|
| 185 |
+
"""
|
| 186 |
+
Add a model to all the relevant mappings in the auto module.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
old_model_infos (`ModelInfos`):
|
| 190 |
+
The structure containing the class information of the old model.
|
| 191 |
+
new_lowercase_name (`str`):
|
| 192 |
+
The new lowercase model name.
|
| 193 |
+
new_model_paper_name (`str`):
|
| 194 |
+
The fully cased name (as in the official paper name) of the new model.
|
| 195 |
+
filenames_to_add (`list[tuple[str, bool]]`):
|
| 196 |
+
A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
|
| 197 |
+
should add this file or not. For example, [(`modeling_xxx.px`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
|
| 198 |
+
"""
|
| 199 |
+
new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
|
| 200 |
+
old_lowercase_name = old_model_infos.lowercase_name
|
| 201 |
+
old_cased_name = old_model_infos.camelcase_name
|
| 202 |
+
filenames_to_add = [
|
| 203 |
+
(filename.replace(old_lowercase_name, "auto"), to_add) for filename, to_add in filenames_to_add[1:]
|
| 204 |
+
]
|
| 205 |
+
# fast tokenizer has the same auto mappings as normal ones
|
| 206 |
+
corrected_filenames_to_add = []
|
| 207 |
+
has_image_processor = has_video_processor = False
|
| 208 |
+
for file, to_add in filenames_to_add:
|
| 209 |
+
if "image_processing" in file:
|
| 210 |
+
has_image_processor = True
|
| 211 |
+
elif "video_processing" in file:
|
| 212 |
+
has_video_processor = True
|
| 213 |
+
elif re.search(r"(?:tokenization)|(?:image_processing)_auto_fast.py", file):
|
| 214 |
+
previous_file, previous_to_add = corrected_filenames_to_add[-1]
|
| 215 |
+
corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add)
|
| 216 |
+
else:
|
| 217 |
+
corrected_filenames_to_add.append((file, to_add))
|
| 218 |
+
|
| 219 |
+
# Add the config and image/video processor mappings directly as the handling is a bit different
|
| 220 |
+
add_content_to_file(
|
| 221 |
+
repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py",
|
| 222 |
+
new_content=f'("{new_lowercase_name}", "{new_cased_name}Config"),\n ',
|
| 223 |
+
add_after="CONFIG_MAPPING_NAMES = OrderedDict(\n [\n ",
|
| 224 |
+
)
|
| 225 |
+
autofile = (repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py").read_text()
|
| 226 |
+
if has_image_processor:
|
| 227 |
+
matching_lines = re.findall(rf'^\s+\("{old_lowercase_name}",\s+{{[^}}]+}}\),?$', autofile, re.MULTILINE)
|
| 228 |
+
if matching_lines:
|
| 229 |
+
match = matching_lines[0]
|
| 230 |
+
add_content_to_file(
|
| 231 |
+
repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py",
|
| 232 |
+
new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
|
| 233 |
+
old_cased_name, new_cased_name
|
| 234 |
+
)
|
| 235 |
+
+ "\n",
|
| 236 |
+
add_after="IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(\n [\n",
|
| 237 |
+
)
|
| 238 |
+
if has_video_processor:
|
| 239 |
+
# Extract the VIDEO_PROCESSOR_MAPPING_NAMES block first
|
| 240 |
+
block_match = re.search(
|
| 241 |
+
r"VIDEO_PROCESSOR_MAPPING_NAMES\s*=\s*OrderedDict\(\s*\[(.*?)\]\s*\)", autofile, re.DOTALL
|
| 242 |
+
)
|
| 243 |
+
block = block_match.group(1) # type: ignore
|
| 244 |
+
matching_lines = re.findall(rf'^\s+\("{old_lowercase_name}",\s+"[^"]+"\),?$', block, re.MULTILINE)
|
| 245 |
+
if matching_lines:
|
| 246 |
+
match = matching_lines[0]
|
| 247 |
+
add_content_to_file(
|
| 248 |
+
repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py",
|
| 249 |
+
new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
|
| 250 |
+
old_cased_name, new_cased_name
|
| 251 |
+
)
|
| 252 |
+
+ "\n",
|
| 253 |
+
add_after="VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(\n [\n",
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
for filename, to_add in corrected_filenames_to_add:
|
| 257 |
+
if to_add:
|
| 258 |
+
# The auto mapping
|
| 259 |
+
filename = filename.replace("_fast.py", ".py")
|
| 260 |
+
file = (repo_path / "src" / "transformers" / "models" / "auto" / filename).read_text()
|
| 261 |
+
# The regex has to be a bit complex like this as the tokenizer mapping has new lines everywhere
|
| 262 |
+
matching_lines = re.findall(
|
| 263 |
+
rf'( {{8,12}}\(\s*"{old_lowercase_name}",.*?\),\n)(?: {{4,12}}\(|\])', file, re.DOTALL
|
| 264 |
+
)
|
| 265 |
+
for match in matching_lines:
|
| 266 |
+
add_content_to_file(
|
| 267 |
+
repo_path / "src" / "transformers" / "models" / "auto" / filename,
|
| 268 |
+
new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
|
| 269 |
+
old_cased_name, new_cased_name
|
| 270 |
+
),
|
| 271 |
+
add_after=match,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def create_doc_file(new_paper_name: str, public_classes: list[str]):
|
| 276 |
+
"""
|
| 277 |
+
Create a new doc file to fill for the new model.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
new_paper_name (`str`):
|
| 281 |
+
The fully cased name (as in the official paper name) of the new model.
|
| 282 |
+
public_classes (`list[str]`):
|
| 283 |
+
A list of all the public classes that the model will have in the library.
|
| 284 |
+
"""
|
| 285 |
+
added_note = (
|
| 286 |
+
"\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that "
|
| 287 |
+
"may not be rendered properly in your Markdown viewer.\n\n-->\n\n"
|
| 288 |
+
)
|
| 289 |
+
copyright_for_markdown = re.sub(r"# ?", "", COPYRIGHT).replace("coding=utf-8\n", "<!--") + added_note
|
| 290 |
+
|
| 291 |
+
doc_template = textwrap.dedent(
|
| 292 |
+
f"""
|
| 293 |
+
# {new_paper_name}
|
| 294 |
+
|
| 295 |
+
## Overview
|
| 296 |
+
|
| 297 |
+
The {new_paper_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
| 298 |
+
<INSERT SHORT SUMMARY HERE>
|
| 299 |
+
|
| 300 |
+
The abstract from the paper is the following:
|
| 301 |
+
|
| 302 |
+
<INSERT PAPER ABSTRACT HERE>
|
| 303 |
+
|
| 304 |
+
Tips:
|
| 305 |
+
|
| 306 |
+
<INSERT TIPS ABOUT MODEL HERE>
|
| 307 |
+
|
| 308 |
+
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
| 309 |
+
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
| 310 |
+
|
| 311 |
+
## Usage examples
|
| 312 |
+
|
| 313 |
+
<INSERT SOME NICE EXAMPLES HERE>
|
| 314 |
+
|
| 315 |
+
"""
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# Add public classes doc
|
| 319 |
+
doc_for_classes = []
|
| 320 |
+
for class_ in public_classes:
|
| 321 |
+
doc = f"## {class_}\n\n[[autodoc]] {class_}"
|
| 322 |
+
if "Model" in class_:
|
| 323 |
+
doc += "\n - forward"
|
| 324 |
+
doc_for_classes.append(doc)
|
| 325 |
+
|
| 326 |
+
class_doc = "\n\n".join(doc_for_classes)
|
| 327 |
+
|
| 328 |
+
return copyright_for_markdown + doc_template + class_doc
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def insert_model_in_doc_toc(
|
| 332 |
+
repo_path: Path, old_lowercase_name: str, new_lowercase_name: str, new_model_paper_name: str
|
| 333 |
+
):
|
| 334 |
+
"""
|
| 335 |
+
Insert the new model in the doc `_toctree.yaml`, in the same section as the old model.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
old_lowercase_name (`str`):
|
| 339 |
+
The old lowercase model name.
|
| 340 |
+
new_lowercase_name (`str`):
|
| 341 |
+
The new lowercase model name.
|
| 342 |
+
new_model_paper_name (`str`):
|
| 343 |
+
The fully cased name (as in the official paper name) of the new model.
|
| 344 |
+
"""
|
| 345 |
+
toc_file = repo_path / "docs" / "source" / "en" / "_toctree.yml"
|
| 346 |
+
with open(toc_file, "r") as f:
|
| 347 |
+
content = f.read()
|
| 348 |
+
|
| 349 |
+
toc_match = re.search(rf"- local: model_doc/{old_lowercase_name}\n {{8}}title: .*?\n", content)
|
| 350 |
+
if toc_match is None:
|
| 351 |
+
raise ValueError(f"Could not find TOC entry for {old_lowercase_name}")
|
| 352 |
+
old_model_toc = toc_match.group(0)
|
| 353 |
+
new_toc = f" - local: model_doc/{new_lowercase_name}\n title: {new_model_paper_name}\n"
|
| 354 |
+
add_content_to_file(
|
| 355 |
+
repo_path / "docs" / "source" / "en" / "_toctree.yml", new_content=new_toc, add_after=old_model_toc
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def create_init_file(old_lowercase_name: str, new_lowercase_name: str, filenames_to_add: list[tuple[str, bool]]):
|
| 360 |
+
"""
|
| 361 |
+
Create the `__init__.py` file to add in the new model folder.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
old_lowercase_name (`str`):
|
| 365 |
+
The old lowercase model name.
|
| 366 |
+
new_lowercase_name (`str`):
|
| 367 |
+
The new lowercase model name.
|
| 368 |
+
filenames_to_add (`list[tuple[str, bool]]`):
|
| 369 |
+
A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
|
| 370 |
+
should add this file or not. For example, [(`modeling_xxx.px`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
|
| 371 |
+
"""
|
| 372 |
+
filenames_to_add = [
|
| 373 |
+
(filename.replace(old_lowercase_name, new_lowercase_name).replace(".py", ""), to_add)
|
| 374 |
+
for filename, to_add in filenames_to_add
|
| 375 |
+
]
|
| 376 |
+
imports = "\n ".join(f"from .{file} import *" for file, to_add in filenames_to_add if to_add)
|
| 377 |
+
init_file = COPYRIGHT + textwrap.dedent(
|
| 378 |
+
f"""
|
| 379 |
+
from typing import TYPE_CHECKING
|
| 380 |
+
|
| 381 |
+
from ...utils import _LazyModule
|
| 382 |
+
from ...utils.import_utils import define_import_structure
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
if TYPE_CHECKING:
|
| 386 |
+
{imports}
|
| 387 |
+
else:
|
| 388 |
+
import sys
|
| 389 |
+
|
| 390 |
+
_file = globals()["__file__"]
|
| 391 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
| 392 |
+
"""
|
| 393 |
+
)
|
| 394 |
+
return init_file
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def find_all_classes_from_file(module_name: str) -> set:
|
| 398 |
+
"""
|
| 399 |
+
Find the name of all classes defined in `module_name`, including public ones (defined in `__all__`).
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
module_name (`str`):
|
| 403 |
+
The full path to the python module from which to extract classes.
|
| 404 |
+
"""
|
| 405 |
+
with open(module_name, "r", encoding="utf-8") as file:
|
| 406 |
+
source_code = file.read()
|
| 407 |
+
module = cst.parse_module(source_code)
|
| 408 |
+
visitor = ClassFinder()
|
| 409 |
+
module.visit(visitor)
|
| 410 |
+
return visitor.classes, visitor.public_classes
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def find_modular_structure(
|
| 414 |
+
module_name: Path, old_model_infos: ModelInfos, new_cased_name: str
|
| 415 |
+
) -> tuple[str, str, list]:
|
| 416 |
+
"""
|
| 417 |
+
Extract the modular structure that will be needed to copy a file `module_name` using modular.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
module_name (`str`):
|
| 421 |
+
The full path to the python module to copy with modular.
|
| 422 |
+
old_model_infos (`ModelInfos`):
|
| 423 |
+
The structure containing the class information of the old model.
|
| 424 |
+
new_cased_name (`str`):
|
| 425 |
+
The new cased model name.
|
| 426 |
+
"""
|
| 427 |
+
all_classes, public_classes = find_all_classes_from_file(module_name)
|
| 428 |
+
import_location = ".".join(module_name.parts[-2:]).replace(".py", "")
|
| 429 |
+
old_cased_name = old_model_infos.camelcase_name
|
| 430 |
+
imports = f"from ..{import_location} import {', '.join(class_ for class_ in all_classes)}"
|
| 431 |
+
modular_classes = "\n\n".join(
|
| 432 |
+
f"class {class_.replace(old_cased_name, new_cased_name)}({class_}):\n pass" for class_ in all_classes
|
| 433 |
+
)
|
| 434 |
+
public_classes = [class_.replace(old_cased_name, new_cased_name) for class_ in public_classes]
|
| 435 |
+
return imports, modular_classes, public_classes
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def create_modular_file(
|
| 439 |
+
repo_path: Path,
|
| 440 |
+
old_model_infos: ModelInfos,
|
| 441 |
+
new_lowercase_name: str,
|
| 442 |
+
filenames_to_add: list[tuple[str, bool]],
|
| 443 |
+
) -> str:
|
| 444 |
+
"""
|
| 445 |
+
Create a new modular file which will copy the old model, based on the new name and the different filenames
|
| 446 |
+
(modules) to add.
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
old_model_infos (`ModelInfos`):
|
| 450 |
+
The structure containing the class information of the old model.
|
| 451 |
+
new_lowercase_name (`str`):
|
| 452 |
+
The new lowercase model name.
|
| 453 |
+
filenames_to_add (`list[tuple[str, bool]]`):
|
| 454 |
+
A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
|
| 455 |
+
should add this file or not. For example, [(`modeling_xxx.px`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
|
| 456 |
+
"""
|
| 457 |
+
new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
|
| 458 |
+
old_lowercase_name = old_model_infos.lowercase_name
|
| 459 |
+
old_folder_root = repo_path / "src" / "transformers" / "models" / old_lowercase_name
|
| 460 |
+
|
| 461 |
+
# Construct the modular file from the original (old) model, by subclassing each class
|
| 462 |
+
all_imports = ""
|
| 463 |
+
all_bodies = ""
|
| 464 |
+
all_public_classes = []
|
| 465 |
+
for filename, to_add in filenames_to_add:
|
| 466 |
+
if to_add:
|
| 467 |
+
imports, body, public_classes = find_modular_structure(
|
| 468 |
+
old_folder_root / filename, old_model_infos, new_cased_name
|
| 469 |
+
)
|
| 470 |
+
all_imports += f"\n{imports}"
|
| 471 |
+
all_bodies += f"\n\n{body}"
|
| 472 |
+
all_public_classes.extend(public_classes)
|
| 473 |
+
|
| 474 |
+
# Create the __all__ assignment
|
| 475 |
+
public_classes_formatted = "\n ".join(f"{public_class}," for public_class in all_public_classes)
|
| 476 |
+
all_statement = textwrap.dedent(
|
| 477 |
+
f"""
|
| 478 |
+
|
| 479 |
+
__all__ = [
|
| 480 |
+
{public_classes_formatted}
|
| 481 |
+
]
|
| 482 |
+
"""
|
| 483 |
+
)
|
| 484 |
+
# Create the whole modular file
|
| 485 |
+
modular_file = COPYRIGHT + all_imports + all_bodies + all_statement
|
| 486 |
+
# Remove outer explicit quotes "" around the public class names before returning them
|
| 487 |
+
all_public_classes = [public_class.replace('"', "") for public_class in all_public_classes]
|
| 488 |
+
return modular_file, all_public_classes
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def create_test_files(
|
| 492 |
+
repo_path: Path, old_model_infos: ModelInfos, new_lowercase_name, filenames_to_add: list[tuple[str, bool]]
|
| 493 |
+
):
|
| 494 |
+
"""
|
| 495 |
+
Create the test files for the new model. It basically copies over the old test files and adjust the class names.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
old_model_infos (`ModelInfos`):
|
| 499 |
+
The structure containing the class information of the old model.
|
| 500 |
+
new_lowercase_name (`str`):
|
| 501 |
+
The new lowercase model name.
|
| 502 |
+
filenames_to_add (`list[tuple[str, bool]]`):
|
| 503 |
+
A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
|
| 504 |
+
should add this file or not. For example, [(`modeling_xxx.px`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
|
| 505 |
+
"""
|
| 506 |
+
new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
|
| 507 |
+
old_lowercase_name = old_model_infos.lowercase_name
|
| 508 |
+
old_cased_name = old_model_infos.camelcase_name
|
| 509 |
+
filenames_to_add = [
|
| 510 |
+
("test_" + filename.replace(old_lowercase_name, new_lowercase_name), to_add)
|
| 511 |
+
for filename, to_add in filenames_to_add[1:]
|
| 512 |
+
]
|
| 513 |
+
# fast tokenizer/image processor have the same test files as normal ones
|
| 514 |
+
corrected_filenames_to_add = []
|
| 515 |
+
for file, to_add in filenames_to_add:
|
| 516 |
+
if re.search(rf"test_(?:tokenization)|(?:image_processing)_{new_lowercase_name}_fast.py", file):
|
| 517 |
+
previous_file, previous_to_add = corrected_filenames_to_add[-1]
|
| 518 |
+
corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add)
|
| 519 |
+
else:
|
| 520 |
+
corrected_filenames_to_add.append((file, to_add))
|
| 521 |
+
|
| 522 |
+
test_files = {}
|
| 523 |
+
for new_file, to_add in corrected_filenames_to_add:
|
| 524 |
+
if to_add:
|
| 525 |
+
original_test_file = new_file.replace(new_lowercase_name, old_lowercase_name)
|
| 526 |
+
original_test_path = repo_path / "tests" / "models" / old_lowercase_name / original_test_file
|
| 527 |
+
# Sometimes, tests may not exist
|
| 528 |
+
if not original_test_path.is_file():
|
| 529 |
+
continue
|
| 530 |
+
with open(original_test_path, "r") as f:
|
| 531 |
+
test_code = f.read()
|
| 532 |
+
# Remove old copyright and add new one
|
| 533 |
+
test_lines = test_code.split("\n")
|
| 534 |
+
idx = 0
|
| 535 |
+
while test_lines[idx].startswith("#"):
|
| 536 |
+
idx += 1
|
| 537 |
+
test_code = COPYRIGHT + "\n".join(test_lines[idx:])
|
| 538 |
+
test_files[new_file] = test_code.replace(old_cased_name, new_cased_name)
|
| 539 |
+
|
| 540 |
+
return test_files
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def _add_new_model_like_internal(
|
| 544 |
+
repo_path: Path,
|
| 545 |
+
old_model_infos: ModelInfos,
|
| 546 |
+
new_lowercase_name: str,
|
| 547 |
+
new_model_paper_name: str,
|
| 548 |
+
filenames_to_add: list[tuple[str, bool]],
|
| 549 |
+
):
|
| 550 |
+
"""
|
| 551 |
+
Creates a new model module like a given model of the Transformers library.
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
repo_path (`Path`):
|
| 555 |
+
The path to the root of the Transformers repository.
|
| 556 |
+
old_model_infos (`ModelInfos`):
|
| 557 |
+
The structure containing the class information of the old model.
|
| 558 |
+
new_lowercase_name (`str`):
|
| 559 |
+
The new lowercase model name.
|
| 560 |
+
new_model_paper_name (`str`):
|
| 561 |
+
The fully cased name (as in the official paper name) of the new model.
|
| 562 |
+
filenames_to_add (`list[tuple[str, bool]]`):
|
| 563 |
+
A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
|
| 564 |
+
should add this file or not. For example, [(`modeling_xxx.px`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
|
| 565 |
+
"""
|
| 566 |
+
# As the import was protected, raise if not present (as it's actually a hard dependency for this command)
|
| 567 |
+
if not is_libcst_available():
|
| 568 |
+
raise ValueError("You need to install `libcst` to run this command -> `pip install libcst`")
|
| 569 |
+
|
| 570 |
+
old_lowercase_name = old_model_infos.lowercase_name
|
| 571 |
+
|
| 572 |
+
# 1. We create the folder for our new model
|
| 573 |
+
new_module_folder = repo_path / "src" / "transformers" / "models" / new_lowercase_name
|
| 574 |
+
os.makedirs(new_module_folder, exist_ok=True)
|
| 575 |
+
|
| 576 |
+
# 2. Create and add the modular file
|
| 577 |
+
modular_file, public_classes = create_modular_file(
|
| 578 |
+
repo_path, old_model_infos, new_lowercase_name, filenames_to_add
|
| 579 |
+
)
|
| 580 |
+
with open(new_module_folder / f"modular_{new_lowercase_name}.py", "w") as f:
|
| 581 |
+
f.write(modular_file)
|
| 582 |
+
|
| 583 |
+
# 3. Create and add the __init__.py
|
| 584 |
+
init_file = create_init_file(old_lowercase_name, new_lowercase_name, filenames_to_add)
|
| 585 |
+
with open(new_module_folder / "__init__.py", "w") as f:
|
| 586 |
+
f.write(init_file)
|
| 587 |
+
|
| 588 |
+
# 4. Add new model to the models init
|
| 589 |
+
add_content_to_file(
|
| 590 |
+
repo_path / "src" / "transformers" / "models" / "__init__.py",
|
| 591 |
+
new_content=f" from .{new_lowercase_name} import *\n",
|
| 592 |
+
add_after="if TYPE_CHECKING:\n",
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# 5. Add model to auto mappings
|
| 596 |
+
add_model_to_auto_mappings(repo_path, old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add)
|
| 597 |
+
|
| 598 |
+
# 6. Add test files
|
| 599 |
+
tests_folder = repo_path / "tests" / "models" / new_lowercase_name
|
| 600 |
+
os.makedirs(tests_folder, exist_ok=True)
|
| 601 |
+
# Add empty __init__.py
|
| 602 |
+
with open(tests_folder / "__init__.py", "w"):
|
| 603 |
+
pass
|
| 604 |
+
test_files = create_test_files(repo_path, old_model_infos, new_lowercase_name, filenames_to_add)
|
| 605 |
+
for filename, content in test_files.items():
|
| 606 |
+
with open(tests_folder / filename, "w") as f:
|
| 607 |
+
f.write(content)
|
| 608 |
+
|
| 609 |
+
# 7. Add doc file
|
| 610 |
+
doc_file = create_doc_file(new_model_paper_name, public_classes)
|
| 611 |
+
with open(repo_path / "docs" / "source" / "en" / "model_doc" / f"{new_lowercase_name}.md", "w") as f:
|
| 612 |
+
f.write(doc_file)
|
| 613 |
+
insert_model_in_doc_toc(repo_path, old_lowercase_name, new_lowercase_name, new_model_paper_name)
|
| 614 |
+
|
| 615 |
+
# 9. Run linters
|
| 616 |
+
model_init_file = repo_path / "src" / "transformers" / "models" / "__init__.py"
|
| 617 |
+
subprocess.run(
|
| 618 |
+
["ruff", "check", new_module_folder, tests_folder, model_init_file, "--fix"],
|
| 619 |
+
cwd=repo_path,
|
| 620 |
+
stdout=subprocess.DEVNULL,
|
| 621 |
+
)
|
| 622 |
+
subprocess.run(
|
| 623 |
+
["ruff", "format", new_module_folder, tests_folder, model_init_file],
|
| 624 |
+
cwd=repo_path,
|
| 625 |
+
stdout=subprocess.DEVNULL,
|
| 626 |
+
)
|
| 627 |
+
subprocess.run(
|
| 628 |
+
["python", "utils/check_doc_toc.py", "--fix_and_overwrite"], cwd=repo_path, stdout=subprocess.DEVNULL
|
| 629 |
+
)
|
| 630 |
+
subprocess.run(["python", "utils/sort_auto_mappings.py"], cwd=repo_path, stdout=subprocess.DEVNULL)
|
| 631 |
+
|
| 632 |
+
# 10. Run the modular conversion
|
| 633 |
+
subprocess.run(
|
| 634 |
+
["python", "utils/modular_model_converter.py", new_lowercase_name], cwd=repo_path, stdout=subprocess.DEVNULL
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def get_user_field(
|
| 639 |
+
question: str,
|
| 640 |
+
default_value: str | None = None,
|
| 641 |
+
convert_to: Callable | None = None,
|
| 642 |
+
fallback_message: str | None = None,
|
| 643 |
+
) -> Any:
|
| 644 |
+
"""
|
| 645 |
+
A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
|
| 646 |
+
answer.
|
| 647 |
+
|
| 648 |
+
Args:
|
| 649 |
+
question (`str`):
|
| 650 |
+
The question to ask the user.
|
| 651 |
+
default_value (`str`, *optional*):
|
| 652 |
+
A potential default value that will be used when the answer is empty.
|
| 653 |
+
convert_to (`Callable`, *optional*):
|
| 654 |
+
If set, the answer will be passed to this function. If this function raises an error on the provided
|
| 655 |
+
answer, the question will be asked again.
|
| 656 |
+
fallback_message (`str`, *optional*):
|
| 657 |
+
A message that will be displayed each time the question is asked again to the user.
|
| 658 |
+
|
| 659 |
+
Returns:
|
| 660 |
+
`Any`: The answer provided by the user (or the default), passed through the potential conversion function.
|
| 661 |
+
"""
|
| 662 |
+
if not question.endswith(" "):
|
| 663 |
+
question = question + " "
|
| 664 |
+
if default_value is not None:
|
| 665 |
+
question = f"{question} [{default_value}] "
|
| 666 |
+
|
| 667 |
+
valid_answer = False
|
| 668 |
+
while not valid_answer:
|
| 669 |
+
answer = input(question)
|
| 670 |
+
if default_value is not None and len(answer) == 0:
|
| 671 |
+
answer = default_value
|
| 672 |
+
if convert_to is not None:
|
| 673 |
+
try:
|
| 674 |
+
answer = convert_to(answer)
|
| 675 |
+
valid_answer = True
|
| 676 |
+
except Exception:
|
| 677 |
+
valid_answer = False
|
| 678 |
+
else:
|
| 679 |
+
valid_answer = True
|
| 680 |
+
|
| 681 |
+
if not valid_answer:
|
| 682 |
+
print(fallback_message)
|
| 683 |
+
|
| 684 |
+
return answer
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def convert_to_bool(x: str) -> bool:
|
| 688 |
+
"""
|
| 689 |
+
Converts a string to a bool.
|
| 690 |
+
"""
|
| 691 |
+
if x.lower() in ["1", "y", "yes", "true"]:
|
| 692 |
+
return True
|
| 693 |
+
if x.lower() in ["0", "n", "no", "false"]:
|
| 694 |
+
return False
|
| 695 |
+
raise ValueError(f"{x} is not a value that can be converted to a bool.")
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def get_user_input():
|
| 699 |
+
"""
|
| 700 |
+
Ask the user for the necessary inputs to add the new model.
|
| 701 |
+
"""
|
| 702 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
|
| 703 |
+
|
| 704 |
+
model_types = list(CONFIG_MAPPING_NAMES.keys())
|
| 705 |
+
|
| 706 |
+
# Get old model type
|
| 707 |
+
valid_model_type = False
|
| 708 |
+
while not valid_model_type:
|
| 709 |
+
old_model_type = input(
|
| 710 |
+
"What model would you like to duplicate? Please provide it as lowercase, e.g. `llama`): "
|
| 711 |
+
)
|
| 712 |
+
if old_model_type in model_types:
|
| 713 |
+
valid_model_type = True
|
| 714 |
+
else:
|
| 715 |
+
print(f"{old_model_type} is not a valid model type.")
|
| 716 |
+
near_choices = difflib.get_close_matches(old_model_type, model_types)
|
| 717 |
+
if len(near_choices) >= 1:
|
| 718 |
+
if len(near_choices) > 1:
|
| 719 |
+
near_choices = " or ".join(near_choices)
|
| 720 |
+
print(f"Did you mean {near_choices}?")
|
| 721 |
+
|
| 722 |
+
old_model_infos = ModelInfos(old_model_type)
|
| 723 |
+
|
| 724 |
+
# Ask for the new model name
|
| 725 |
+
new_lowercase_name = get_user_field(
|
| 726 |
+
"What is the new model name? Please provide it as snake lowercase, e.g. `new_model`?"
|
| 727 |
+
)
|
| 728 |
+
new_model_paper_name = get_user_field(
|
| 729 |
+
"What is the fully cased name you would like to appear in the doc (e.g. `NeW ModEl`)? ",
|
| 730 |
+
default_value="".join(x.title() for x in new_lowercase_name.split("_")),
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# Ask if we want to add individual processor classes as well
|
| 734 |
+
add_tokenizer = False
|
| 735 |
+
add_fast_tokenizer = False
|
| 736 |
+
add_image_processor = False
|
| 737 |
+
add_video_processor = False
|
| 738 |
+
add_feature_extractor = False
|
| 739 |
+
add_processor = False
|
| 740 |
+
if old_model_infos.tokenizer_class is not None:
|
| 741 |
+
add_tokenizer = get_user_field(
|
| 742 |
+
f"Do you want to create a new tokenizer? If `no`, it will use the same as {old_model_type} (y/n)?",
|
| 743 |
+
convert_to=convert_to_bool,
|
| 744 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
| 745 |
+
)
|
| 746 |
+
if old_model_infos.fast_tokenizer_class is not None:
|
| 747 |
+
add_fast_tokenizer = get_user_field(
|
| 748 |
+
f"Do you want to create a new fast tokenizer? If `no`, it will use the same as {old_model_type} (y/n)?",
|
| 749 |
+
convert_to=convert_to_bool,
|
| 750 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
| 751 |
+
)
|
| 752 |
+
if old_model_infos.image_processor_classes is not None:
|
| 753 |
+
add_image_processor = get_user_field(
|
| 754 |
+
f"Do you want to create a new image processor? If `no`, it will use the same as {old_model_type} (y/n)?",
|
| 755 |
+
convert_to=convert_to_bool,
|
| 756 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
| 757 |
+
)
|
| 758 |
+
if old_model_infos.video_processor_class is not None:
|
| 759 |
+
add_video_processor = get_user_field(
|
| 760 |
+
f"Do you want to create a new video processor? If `no`, it will use the same as {old_model_type} (y/n)?",
|
| 761 |
+
convert_to=convert_to_bool,
|
| 762 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
| 763 |
+
)
|
| 764 |
+
if old_model_infos.feature_extractor_class is not None:
|
| 765 |
+
add_feature_extractor = get_user_field(
|
| 766 |
+
f"Do you want to create a new feature extractor? If `no`, it will use the same as {old_model_type} (y/n)?",
|
| 767 |
+
convert_to=convert_to_bool,
|
| 768 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
| 769 |
+
)
|
| 770 |
+
if old_model_infos.processor_class is not None:
|
| 771 |
+
add_processor = get_user_field(
|
| 772 |
+
f"Do you want to create a new processor? If `no`, it will use the same as {old_model_type} (y/n)?",
|
| 773 |
+
convert_to=convert_to_bool,
|
| 774 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
old_lowercase_name = old_model_infos.lowercase_name
|
| 778 |
+
# A list of the old filenames, along whether we should copy them or not
|
| 779 |
+
filenames_to_add = (
|
| 780 |
+
(f"configuration_{old_lowercase_name}.py", True),
|
| 781 |
+
(f"modeling_{old_lowercase_name}.py", True),
|
| 782 |
+
(f"tokenization_{old_lowercase_name}.py", add_tokenizer),
|
| 783 |
+
(f"tokenization_{old_lowercase_name}_fast.py", add_fast_tokenizer),
|
| 784 |
+
(f"image_processing_{old_lowercase_name}.py", add_image_processor),
|
| 785 |
+
(f"video_processing_{old_lowercase_name}.py", add_video_processor),
|
| 786 |
+
(f"feature_extraction_{old_lowercase_name}.py", add_feature_extractor),
|
| 787 |
+
(f"processing_{old_lowercase_name}.py", add_processor),
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
return old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/chat.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import asyncio
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import platform
|
| 18 |
+
import re
|
| 19 |
+
import string
|
| 20 |
+
import time
|
| 21 |
+
from collections.abc import AsyncIterator, Awaitable
|
| 22 |
+
from typing import Annotated, Any
|
| 23 |
+
from urllib.parse import urljoin, urlparse
|
| 24 |
+
|
| 25 |
+
import httpx
|
| 26 |
+
import requests
|
| 27 |
+
import typer
|
| 28 |
+
import yaml
|
| 29 |
+
from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput
|
| 30 |
+
|
| 31 |
+
from transformers import GenerationConfig
|
| 32 |
+
from transformers.utils import is_rich_available
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import readline # noqa importing this enables GNU readline capabilities
|
| 37 |
+
except ImportError:
|
| 38 |
+
# some platforms may not support readline: https://docs.python.org/3/library/readline.html
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
if platform.system() != "Windows":
|
| 42 |
+
import pwd
|
| 43 |
+
|
| 44 |
+
if is_rich_available():
|
| 45 |
+
from rich import filesize
|
| 46 |
+
from rich.console import Console
|
| 47 |
+
from rich.live import Live
|
| 48 |
+
from rich.markdown import Markdown
|
| 49 |
+
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn, TimeElapsedColumn
|
| 50 |
+
from rich.text import Text
|
| 51 |
+
|
| 52 |
+
DEFAULT_HTTP_ENDPOINT = {"hostname": "localhost", "port": 8000}
|
| 53 |
+
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
| 54 |
+
ALLOWED_VALUE_CHARS = set(
|
| 55 |
+
string.ascii_letters + string.digits + string.whitespace + r".!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
DEFAULT_EXAMPLES = {
|
| 59 |
+
"llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
|
| 60 |
+
"code": {
|
| 61 |
+
"text": (
|
| 62 |
+
"Write a Python function that integrates any Python function f(x) numerically over an arbitrary "
|
| 63 |
+
"interval [x_start, x_end]."
|
| 64 |
+
),
|
| 65 |
+
},
|
| 66 |
+
"helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
|
| 67 |
+
"numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
|
| 68 |
+
"birds": {"text": "Why aren't birds real?"},
|
| 69 |
+
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
| 70 |
+
"numbers2": {"text": "Which number is larger, 9.9 or 9.11?"},
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
# Printed at the start of a chat session
|
| 74 |
+
HELP_STRING_MINIMAL = """
|
| 75 |
+
|
| 76 |
+
**TRANSFORMERS CHAT INTERFACE**
|
| 77 |
+
|
| 78 |
+
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
|
| 79 |
+
- **!help**: shows all available commands (set generation settings, save chat, etc.)
|
| 80 |
+
- **!status**: shows the current status of the model and generation settings
|
| 81 |
+
- **!clear**: clears the current conversation and starts a new one
|
| 82 |
+
- **!exit**: closes the interface
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Printed when the user types `help` in the chat session
|
| 87 |
+
HELP_STRING = f"""
|
| 88 |
+
|
| 89 |
+
**TRANSFORMERS CHAT INTERFACE HELP**
|
| 90 |
+
|
| 91 |
+
Full command list:
|
| 92 |
+
- **!help**: shows this help message
|
| 93 |
+
- **!clear**: clears the current conversation and starts a new one
|
| 94 |
+
- **!status**: shows the current status of the model and generation settings
|
| 95 |
+
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
|
| 96 |
+
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
| 97 |
+
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
|
| 98 |
+
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
|
| 99 |
+
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
|
| 100 |
+
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
| 101 |
+
`./chat_history/{{MODEL_ID}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
|
| 102 |
+
- **!exit**: closes the interface
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class RichInterface:
|
| 107 |
+
def __init__(self, model_id: str, user_id: str, base_url: str):
|
| 108 |
+
self._console = Console()
|
| 109 |
+
self.model_id = model_id
|
| 110 |
+
self.user_id = user_id
|
| 111 |
+
self.base_url = base_url
|
| 112 |
+
|
| 113 |
+
async def stream_output(
|
| 114 |
+
self, stream: Awaitable[AsyncIterator[ChatCompletionStreamOutput]]
|
| 115 |
+
) -> tuple[str, str | Any | None]:
|
| 116 |
+
self._console.print(f"[bold blue]<{self.model_id}>:")
|
| 117 |
+
with Live(console=self._console, refresh_per_second=4) as live:
|
| 118 |
+
text = ""
|
| 119 |
+
completion_tokens = 0
|
| 120 |
+
start_time = time.time()
|
| 121 |
+
finish_reason: str | None = None
|
| 122 |
+
async for token in await stream:
|
| 123 |
+
outputs = token.choices[0].delta.content
|
| 124 |
+
finish_reason = getattr(token.choices[0], "finish_reason", finish_reason)
|
| 125 |
+
|
| 126 |
+
usage = getattr(token, "usage", None)
|
| 127 |
+
if usage is not None:
|
| 128 |
+
completion_tokens = getattr(usage, "completion_tokens", completion_tokens)
|
| 129 |
+
|
| 130 |
+
if not outputs:
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
|
| 134 |
+
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
|
| 135 |
+
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
|
| 136 |
+
|
| 137 |
+
text += outputs
|
| 138 |
+
# Render the accumulated text as Markdown
|
| 139 |
+
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
| 140 |
+
# in rich. The chatbots output treat "\n" as a new line for
|
| 141 |
+
# better compatibility with real-world text. However, rendering
|
| 142 |
+
# in markdown would break the format. It is because standard markdown
|
| 143 |
+
# treat a single "\n" in normal text as a space.
|
| 144 |
+
# Our workaround is adding two spaces at the end of each line.
|
| 145 |
+
# This is not a perfect solution, as it would
|
| 146 |
+
# introduce trailing spaces (only) in code block, but it works well
|
| 147 |
+
# especially for console output, because in general the console does not
|
| 148 |
+
# care about trailing spaces.
|
| 149 |
+
|
| 150 |
+
lines = []
|
| 151 |
+
for line in text.splitlines():
|
| 152 |
+
lines.append(line)
|
| 153 |
+
if line.startswith("```"):
|
| 154 |
+
# Code block marker - do not add trailing spaces, as it would
|
| 155 |
+
# break the syntax highlighting
|
| 156 |
+
lines.append("\n")
|
| 157 |
+
else:
|
| 158 |
+
lines.append(" \n")
|
| 159 |
+
|
| 160 |
+
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
| 161 |
+
|
| 162 |
+
# Update the Live console output
|
| 163 |
+
live.update(markdown, refresh=True)
|
| 164 |
+
|
| 165 |
+
elapsed = time.time() - start_time
|
| 166 |
+
if elapsed > 0 and completion_tokens > 0:
|
| 167 |
+
tok_per_sec = completion_tokens / elapsed
|
| 168 |
+
self._console.print()
|
| 169 |
+
self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]")
|
| 170 |
+
self._console.print()
|
| 171 |
+
|
| 172 |
+
return text, finish_reason
|
| 173 |
+
|
| 174 |
+
def input(self) -> str:
|
| 175 |
+
"""Gets user input from the console."""
|
| 176 |
+
input = self._console.input(f"[bold red]<{self.user_id}>:\n")
|
| 177 |
+
self._console.print()
|
| 178 |
+
return input
|
| 179 |
+
|
| 180 |
+
def clear(self):
|
| 181 |
+
"""Clears the console."""
|
| 182 |
+
self._console.clear()
|
| 183 |
+
|
| 184 |
+
def print_user_message(self, text: str):
|
| 185 |
+
"""Prints a user message to the console."""
|
| 186 |
+
self._console.print(f"[bold red]<{self.user_id}>:[/ bold red]\n{text}")
|
| 187 |
+
self._console.print()
|
| 188 |
+
|
| 189 |
+
def print_color(self, text: str, color: str):
|
| 190 |
+
"""Prints text in a given color to the console."""
|
| 191 |
+
self._console.print(f"[bold {color}]{text}")
|
| 192 |
+
self._console.print()
|
| 193 |
+
|
| 194 |
+
def confirm(self, message: str, default: bool = False) -> bool:
|
| 195 |
+
"""Displays a yes/no prompt to the user, returning True for confirmation."""
|
| 196 |
+
default_hint = "Y/n" if default else "y/N"
|
| 197 |
+
response = self._console.input(f"[bold yellow]{message} ({default_hint}): ")
|
| 198 |
+
self._console.print()
|
| 199 |
+
|
| 200 |
+
response = response.strip().lower()
|
| 201 |
+
if not response:
|
| 202 |
+
return default
|
| 203 |
+
|
| 204 |
+
return response in {"y", "yes"}
|
| 205 |
+
|
| 206 |
+
def print_help(self, minimal: bool = False):
|
| 207 |
+
"""Prints the help message to the console."""
|
| 208 |
+
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
|
| 209 |
+
self._console.print()
|
| 210 |
+
|
| 211 |
+
def print_model_load(self, model: str):
|
| 212 |
+
response = requests.post(f"{self.base_url.rstrip('/')}/load_model", json={"model": model}, stream=True)
|
| 213 |
+
response.raise_for_status()
|
| 214 |
+
|
| 215 |
+
class StatsColumn(ProgressColumn):
|
| 216 |
+
def render(self, task):
|
| 217 |
+
if not task.total:
|
| 218 |
+
return Text("")
|
| 219 |
+
|
| 220 |
+
if task.fields.get("unit") == "bytes":
|
| 221 |
+
done = filesize.decimal(int(task.completed))
|
| 222 |
+
tot = filesize.decimal(int(task.total))
|
| 223 |
+
speed = f" {filesize.decimal(int(task.speed))}/s" if task.speed else ""
|
| 224 |
+
|
| 225 |
+
if task.time_remaining is not None:
|
| 226 |
+
eta = f" {int(task.time_remaining // 60)}:{int(task.time_remaining % 60):02d}"
|
| 227 |
+
else:
|
| 228 |
+
eta = ""
|
| 229 |
+
|
| 230 |
+
return Text(f"{done}/{tot}{speed}{eta}", style="progress.download")
|
| 231 |
+
return Text(f"{int(task.completed)}/{int(task.total)}")
|
| 232 |
+
|
| 233 |
+
stage_labels = {
|
| 234 |
+
"processor": "Loading processor",
|
| 235 |
+
"config": "Loading config",
|
| 236 |
+
"download": "Downloading files",
|
| 237 |
+
"weights": "Loading into memory",
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
# Include the model name prefix in descriptions only when the terminal is wide enough.
|
| 241 |
+
# The bar, stats, and elapsed columns need ~70 chars; the model prefix needs len(model)+5.
|
| 242 |
+
show_model_prefix = self._console.width >= len(model) + 5 + 70
|
| 243 |
+
|
| 244 |
+
def _label(stage_key):
|
| 245 |
+
stage_text = stage_labels.get(stage_key, stage_key)
|
| 246 |
+
if show_model_prefix:
|
| 247 |
+
return f"{model} → {stage_text}"
|
| 248 |
+
return stage_text
|
| 249 |
+
|
| 250 |
+
progress = Progress(
|
| 251 |
+
TextColumn("[bold]{task.description}"),
|
| 252 |
+
BarColumn(bar_width=40),
|
| 253 |
+
StatsColumn(),
|
| 254 |
+
TimeElapsedColumn(),
|
| 255 |
+
console=self._console,
|
| 256 |
+
)
|
| 257 |
+
task_id = progress.add_task(_label("processor"), total=None)
|
| 258 |
+
cached = False
|
| 259 |
+
|
| 260 |
+
with Live(progress, console=self._console, transient=True):
|
| 261 |
+
for line in response.iter_lines():
|
| 262 |
+
if not line or not line.startswith(b"data: "):
|
| 263 |
+
continue
|
| 264 |
+
event = json.loads(line[6:])
|
| 265 |
+
status = event.get("status")
|
| 266 |
+
|
| 267 |
+
if status == "ready":
|
| 268 |
+
cached = event.get("cached", False)
|
| 269 |
+
break
|
| 270 |
+
|
| 271 |
+
if status == "error":
|
| 272 |
+
raise RuntimeError(event.get("message", "Unknown error"))
|
| 273 |
+
|
| 274 |
+
if status == "loading":
|
| 275 |
+
stage = event.get("stage")
|
| 276 |
+
prog = event.get("progress")
|
| 277 |
+
label = _label(stage)
|
| 278 |
+
|
| 279 |
+
if prog:
|
| 280 |
+
unit = "bytes" if stage == "download" else "items"
|
| 281 |
+
progress.update(
|
| 282 |
+
task_id, description=label, completed=prog["current"], total=prog.get("total"), unit=unit
|
| 283 |
+
)
|
| 284 |
+
else:
|
| 285 |
+
progress.update(task_id, description=label, completed=0, total=None)
|
| 286 |
+
|
| 287 |
+
if cached:
|
| 288 |
+
self._console.print(Markdown(f"_*{model} was already loaded.*_"))
|
| 289 |
+
else:
|
| 290 |
+
self._console.print(Markdown(f"_*{model} is warm.*_"))
|
| 291 |
+
self._console.print()
|
| 292 |
+
|
| 293 |
+
def print_status(self, config: GenerationConfig):
|
| 294 |
+
"""Prints the status of the model and generation settings to the console."""
|
| 295 |
+
self._console.print(f"[bold blue]Model: {self.model_id}\n")
|
| 296 |
+
self._console.print(f"[bold blue]{config}")
|
| 297 |
+
self._console.print()
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class Chat:
|
| 301 |
+
"""Chat with a model from the command line."""
|
| 302 |
+
|
| 303 |
+
# Defining a class to help with internal state but in practice it's just a method to call
|
| 304 |
+
# TODO: refactor into a proper module with helpers + 1 main method
|
| 305 |
+
def __init__(
|
| 306 |
+
self,
|
| 307 |
+
model_id: Annotated[str, typer.Argument(help="ID of the model to use (e.g. 'HuggingFaceTB/SmolLM3-3B').")],
|
| 308 |
+
base_url: Annotated[
|
| 309 |
+
str | None, typer.Argument(help="Base url to connect to (e.g. http://localhost:8000/v1).")
|
| 310 |
+
] = f"http://{DEFAULT_HTTP_ENDPOINT['hostname']}:{DEFAULT_HTTP_ENDPOINT['port']}",
|
| 311 |
+
generate_flags: Annotated[
|
| 312 |
+
list[str] | None,
|
| 313 |
+
typer.Argument(
|
| 314 |
+
help=(
|
| 315 |
+
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
|
| 316 |
+
"and lists of integers, more advanced parameterization should be set through --generation-config. "
|
| 317 |
+
"Example: `transformers chat <base_url> <model_id> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
|
| 318 |
+
"If you're a new user, check this basic flag guide: "
|
| 319 |
+
"https://huggingface.co/docs/transformers/llm_tutorial#common-options"
|
| 320 |
+
)
|
| 321 |
+
),
|
| 322 |
+
] = None,
|
| 323 |
+
# General settings
|
| 324 |
+
user: Annotated[
|
| 325 |
+
str | None,
|
| 326 |
+
typer.Option(help="Username to display in chat interface. Defaults to the current user's name."),
|
| 327 |
+
] = None,
|
| 328 |
+
system_prompt: Annotated[str | None, typer.Option(help="System prompt.")] = None,
|
| 329 |
+
save_folder: Annotated[str, typer.Option(help="Folder to save chat history.")] = "./chat_history/",
|
| 330 |
+
examples_path: Annotated[str | None, typer.Option(help="Path to a yaml file with examples.")] = None,
|
| 331 |
+
# Generation settings
|
| 332 |
+
generation_config: Annotated[
|
| 333 |
+
str | None,
|
| 334 |
+
typer.Option(
|
| 335 |
+
help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config."
|
| 336 |
+
),
|
| 337 |
+
] = None,
|
| 338 |
+
) -> None:
|
| 339 |
+
"""Chat with a model from the command line."""
|
| 340 |
+
self.base_url = base_url
|
| 341 |
+
|
| 342 |
+
parsed = urlparse(self.base_url)
|
| 343 |
+
if parsed.hostname == DEFAULT_HTTP_ENDPOINT["hostname"] and parsed.port == DEFAULT_HTTP_ENDPOINT["port"]:
|
| 344 |
+
self.check_health(self.base_url)
|
| 345 |
+
|
| 346 |
+
self.model_id = model_id
|
| 347 |
+
self.system_prompt = system_prompt
|
| 348 |
+
self.save_folder = save_folder
|
| 349 |
+
|
| 350 |
+
# Generation settings
|
| 351 |
+
config = load_generation_config(generation_config)
|
| 352 |
+
config.update(do_sample=True, max_new_tokens=256) # some default values
|
| 353 |
+
config.update(**parse_generate_flags(generate_flags))
|
| 354 |
+
self.config = config
|
| 355 |
+
|
| 356 |
+
self.settings = {"base_url": base_url, "model_id": model_id, "config": self.config.to_dict()}
|
| 357 |
+
|
| 358 |
+
# User settings
|
| 359 |
+
self.user = user if user is not None else get_username()
|
| 360 |
+
|
| 361 |
+
# Load examples
|
| 362 |
+
if examples_path:
|
| 363 |
+
with open(examples_path) as f:
|
| 364 |
+
self.examples = yaml.safe_load(f)
|
| 365 |
+
else:
|
| 366 |
+
self.examples = DEFAULT_EXAMPLES
|
| 367 |
+
|
| 368 |
+
# Check requirements
|
| 369 |
+
if not is_rich_available():
|
| 370 |
+
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
|
| 371 |
+
|
| 372 |
+
# Run chat session
|
| 373 |
+
asyncio.run(self._inner_run())
|
| 374 |
+
|
| 375 |
+
@staticmethod
|
| 376 |
+
def check_health(url):
|
| 377 |
+
health_url = urljoin(url + "/", "health")
|
| 378 |
+
try:
|
| 379 |
+
output = httpx.get(health_url)
|
| 380 |
+
if output.status_code != 200:
|
| 381 |
+
raise ValueError(
|
| 382 |
+
f"The server running on {url} returned status code {output.status_code} on health check (/health)."
|
| 383 |
+
)
|
| 384 |
+
except httpx.ConnectError:
|
| 385 |
+
raise ValueError(
|
| 386 |
+
f"No server currently running on {url}. To run a local server, please run `transformers serve` in a"
|
| 387 |
+
f"separate shell. Find more information here: https://huggingface.co/docs/transformers/serving"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
return True
|
| 391 |
+
|
| 392 |
+
def handle_non_exit_user_commands(
|
| 393 |
+
self,
|
| 394 |
+
user_input: str,
|
| 395 |
+
interface: RichInterface,
|
| 396 |
+
examples: dict[str, dict[str, str]],
|
| 397 |
+
config: GenerationConfig,
|
| 398 |
+
chat: list[dict],
|
| 399 |
+
) -> tuple[list[dict], GenerationConfig]:
|
| 400 |
+
"""
|
| 401 |
+
Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
|
| 402 |
+
generation config (e.g. set a new flag).
|
| 403 |
+
"""
|
| 404 |
+
valid_command = True
|
| 405 |
+
|
| 406 |
+
if user_input == "!clear":
|
| 407 |
+
chat = new_chat_history(self.system_prompt)
|
| 408 |
+
interface.clear()
|
| 409 |
+
|
| 410 |
+
elif user_input == "!help":
|
| 411 |
+
interface.print_help()
|
| 412 |
+
|
| 413 |
+
elif user_input.startswith("!save") and len(user_input.split()) < 2:
|
| 414 |
+
split_input = user_input.split()
|
| 415 |
+
filename = (
|
| 416 |
+
split_input[1]
|
| 417 |
+
if len(split_input) == 2
|
| 418 |
+
else os.path.join(self.save_folder, self.model_id, f"chat_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json")
|
| 419 |
+
)
|
| 420 |
+
save_chat(filename=filename, chat=chat, settings=self.settings)
|
| 421 |
+
interface.print_color(text=f"Chat saved to {filename}!", color="green")
|
| 422 |
+
|
| 423 |
+
elif user_input.startswith("!set"):
|
| 424 |
+
# splits the new args into a list of strings, each string being a `flag=value` pair (same format as
|
| 425 |
+
# `generate_flags`)
|
| 426 |
+
new_generate_flags = user_input[4:].strip()
|
| 427 |
+
new_generate_flags = new_generate_flags.split()
|
| 428 |
+
# sanity check: each member in the list must have an =
|
| 429 |
+
for flag in new_generate_flags:
|
| 430 |
+
if "=" not in flag:
|
| 431 |
+
interface.print_color(
|
| 432 |
+
text=(
|
| 433 |
+
f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
|
| 434 |
+
"`arg_1=value_1 arg_2=value_2 ...`."
|
| 435 |
+
),
|
| 436 |
+
color="red",
|
| 437 |
+
)
|
| 438 |
+
break
|
| 439 |
+
else:
|
| 440 |
+
# Update config from user flags
|
| 441 |
+
config.update(**parse_generate_flags(new_generate_flags))
|
| 442 |
+
|
| 443 |
+
elif user_input.startswith("!example") and len(user_input.split()) == 2:
|
| 444 |
+
example_name = user_input.split()[1]
|
| 445 |
+
if example_name in examples:
|
| 446 |
+
interface.clear()
|
| 447 |
+
chat = []
|
| 448 |
+
interface.print_user_message(examples[example_name]["text"])
|
| 449 |
+
chat.append({"role": "user", "content": examples[example_name]["text"]})
|
| 450 |
+
else:
|
| 451 |
+
example_error = (
|
| 452 |
+
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
| 453 |
+
)
|
| 454 |
+
interface.print_color(text=example_error, color="red")
|
| 455 |
+
|
| 456 |
+
elif user_input == "!status":
|
| 457 |
+
interface.print_status(config=config)
|
| 458 |
+
|
| 459 |
+
else:
|
| 460 |
+
valid_command = False
|
| 461 |
+
interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
|
| 462 |
+
interface.print_help()
|
| 463 |
+
|
| 464 |
+
return chat, valid_command, config
|
| 465 |
+
|
| 466 |
+
async def _inner_run(self):
|
| 467 |
+
interface = RichInterface(model_id=self.model_id, user_id=self.user, base_url=self.base_url)
|
| 468 |
+
interface.clear()
|
| 469 |
+
chat = new_chat_history(self.system_prompt)
|
| 470 |
+
|
| 471 |
+
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
| 472 |
+
interface.print_help(minimal=True)
|
| 473 |
+
interface.print_model_load(self.model_id)
|
| 474 |
+
|
| 475 |
+
config = self.config
|
| 476 |
+
|
| 477 |
+
async with AsyncInferenceClient(base_url=self.base_url) as client:
|
| 478 |
+
pending_user_input: str | None = None
|
| 479 |
+
while True:
|
| 480 |
+
try:
|
| 481 |
+
if pending_user_input is not None:
|
| 482 |
+
user_input = pending_user_input
|
| 483 |
+
pending_user_input = None
|
| 484 |
+
interface.print_user_message(user_input)
|
| 485 |
+
else:
|
| 486 |
+
user_input = interface.input()
|
| 487 |
+
|
| 488 |
+
# User commands
|
| 489 |
+
if user_input == "!exit":
|
| 490 |
+
break
|
| 491 |
+
|
| 492 |
+
elif user_input == "!clear":
|
| 493 |
+
chat = new_chat_history(self.system_prompt)
|
| 494 |
+
interface.clear()
|
| 495 |
+
continue
|
| 496 |
+
|
| 497 |
+
elif user_input == "!help":
|
| 498 |
+
interface.print_help()
|
| 499 |
+
continue
|
| 500 |
+
|
| 501 |
+
elif user_input.startswith("!save") and len(user_input.split()) < 2:
|
| 502 |
+
split_input = user_input.split()
|
| 503 |
+
filename = (
|
| 504 |
+
split_input[1]
|
| 505 |
+
if len(split_input) == 2
|
| 506 |
+
else os.path.join(
|
| 507 |
+
self.save_folder, self.model_id, f"chat_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json"
|
| 508 |
+
)
|
| 509 |
+
)
|
| 510 |
+
save_chat(filename=filename, chat=chat, settings=self.settings)
|
| 511 |
+
interface.print_color(text=f"Chat saved to {filename}!", color="green")
|
| 512 |
+
continue
|
| 513 |
+
|
| 514 |
+
elif user_input.startswith("!set"):
|
| 515 |
+
# splits the new args into a list of strings, each string being a `flag=value` pair (same format as
|
| 516 |
+
# `generate_flags`)
|
| 517 |
+
new_generate_flags = user_input[4:].strip()
|
| 518 |
+
new_generate_flags = new_generate_flags.split()
|
| 519 |
+
# sanity check: each member in the list must have an =
|
| 520 |
+
for flag in new_generate_flags:
|
| 521 |
+
if "=" not in flag:
|
| 522 |
+
interface.print_color(
|
| 523 |
+
text=(
|
| 524 |
+
f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
|
| 525 |
+
"`arg_1=value_1 arg_2=value_2 ...`."
|
| 526 |
+
),
|
| 527 |
+
color="red",
|
| 528 |
+
)
|
| 529 |
+
break
|
| 530 |
+
else:
|
| 531 |
+
# Update config from user flags
|
| 532 |
+
config.update(**parse_generate_flags(new_generate_flags))
|
| 533 |
+
continue
|
| 534 |
+
|
| 535 |
+
elif user_input.startswith("!example") and len(user_input.split()) == 2:
|
| 536 |
+
example_name = user_input.split()[1]
|
| 537 |
+
if example_name in self.examples:
|
| 538 |
+
interface.clear()
|
| 539 |
+
chat = []
|
| 540 |
+
interface.print_user_message(self.examples[example_name]["text"])
|
| 541 |
+
chat.append({"role": "user", "content": self.examples[example_name]["text"]})
|
| 542 |
+
else:
|
| 543 |
+
example_error = f"Example {example_name} not found in list of available examples: {list(self.examples.keys())}."
|
| 544 |
+
interface.print_color(text=example_error, color="red")
|
| 545 |
+
|
| 546 |
+
elif user_input == "!status":
|
| 547 |
+
interface.print_status(config=config)
|
| 548 |
+
continue
|
| 549 |
+
|
| 550 |
+
elif user_input.startswith("!"):
|
| 551 |
+
interface.print_color(
|
| 552 |
+
text=f"'{user_input}' is not a valid command. Showing help message.", color="red"
|
| 553 |
+
)
|
| 554 |
+
interface.print_help()
|
| 555 |
+
continue
|
| 556 |
+
|
| 557 |
+
else:
|
| 558 |
+
chat.append({"role": "user", "content": user_input})
|
| 559 |
+
|
| 560 |
+
extra_body = {
|
| 561 |
+
"generation_config": config.to_json_string(),
|
| 562 |
+
"model": self.model_id,
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
stream = client.chat_completion(
|
| 566 |
+
chat,
|
| 567 |
+
stream=True,
|
| 568 |
+
model=self.model_id,
|
| 569 |
+
extra_body=extra_body,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
model_output, finish_reason = await interface.stream_output(stream)
|
| 573 |
+
|
| 574 |
+
chat.append({"role": "assistant", "content": model_output})
|
| 575 |
+
|
| 576 |
+
if finish_reason == "length":
|
| 577 |
+
interface.print_color("Generation stopped after reaching the token limit.", "yellow")
|
| 578 |
+
if interface.confirm("Continue generating?"):
|
| 579 |
+
pending_user_input = "Please continue. Do not repeat text.”"
|
| 580 |
+
continue
|
| 581 |
+
except KeyboardInterrupt:
|
| 582 |
+
break
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def load_generation_config(generation_config: str | None) -> GenerationConfig:
|
| 586 |
+
if generation_config is None:
|
| 587 |
+
return GenerationConfig()
|
| 588 |
+
|
| 589 |
+
if ".json" in generation_config: # is a local file
|
| 590 |
+
dirname = os.path.dirname(generation_config)
|
| 591 |
+
filename = os.path.basename(generation_config)
|
| 592 |
+
return GenerationConfig.from_pretrained(dirname, filename)
|
| 593 |
+
else:
|
| 594 |
+
return GenerationConfig.from_pretrained(generation_config)
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def parse_generate_flags(generate_flags: list[str] | None) -> dict:
|
| 598 |
+
"""Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
|
| 599 |
+
if generate_flags is None or len(generate_flags) == 0:
|
| 600 |
+
return {}
|
| 601 |
+
|
| 602 |
+
# Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
|
| 603 |
+
# into a json string if we:
|
| 604 |
+
# 1. Add quotes around each flag name
|
| 605 |
+
generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
|
| 606 |
+
|
| 607 |
+
# 2. Handle types:
|
| 608 |
+
# 2. a. booleans should be lowercase, None should be null
|
| 609 |
+
generate_flags_as_dict = {
|
| 610 |
+
k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
|
| 611 |
+
}
|
| 612 |
+
generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
|
| 613 |
+
|
| 614 |
+
# 2. b. strings should be quoted
|
| 615 |
+
def is_number(s: str) -> bool:
|
| 616 |
+
# handle negative numbers
|
| 617 |
+
s = s.removeprefix("-")
|
| 618 |
+
return s.replace(".", "", 1).isdigit()
|
| 619 |
+
|
| 620 |
+
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
|
| 621 |
+
# 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
|
| 622 |
+
# We also mention in the help message that we only accept lists of ints for now.
|
| 623 |
+
|
| 624 |
+
# 3. Join the result into a comma separated string
|
| 625 |
+
generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
|
| 626 |
+
|
| 627 |
+
# 4. Add the opening/closing brackets
|
| 628 |
+
generate_flags_string = "{" + generate_flags_string + "}"
|
| 629 |
+
|
| 630 |
+
# 5. Remove quotes around boolean/null and around lists
|
| 631 |
+
generate_flags_string = generate_flags_string.replace('"null"', "null")
|
| 632 |
+
generate_flags_string = generate_flags_string.replace('"true"', "true")
|
| 633 |
+
generate_flags_string = generate_flags_string.replace('"false"', "false")
|
| 634 |
+
generate_flags_string = generate_flags_string.replace('"[', "[")
|
| 635 |
+
generate_flags_string = generate_flags_string.replace(']"', "]")
|
| 636 |
+
|
| 637 |
+
# 6. Replace the `=` with `:`
|
| 638 |
+
generate_flags_string = generate_flags_string.replace("=", ":")
|
| 639 |
+
|
| 640 |
+
try:
|
| 641 |
+
processed_generate_flags = json.loads(generate_flags_string)
|
| 642 |
+
except json.JSONDecodeError:
|
| 643 |
+
raise ValueError(
|
| 644 |
+
"Failed to convert `generate_flags` into a valid JSON object."
|
| 645 |
+
"\n`generate_flags` = {generate_flags}"
|
| 646 |
+
"\nConverted JSON string = {generate_flags_string}"
|
| 647 |
+
)
|
| 648 |
+
return processed_generate_flags
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def new_chat_history(system_prompt: str | None = None) -> list[dict]:
|
| 652 |
+
"""Returns a new chat conversation."""
|
| 653 |
+
return [{"role": "system", "content": system_prompt}] if system_prompt else []
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def save_chat(filename: str, chat: list[dict], settings: dict) -> str:
|
| 657 |
+
"""Saves the chat history to a file."""
|
| 658 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
| 659 |
+
with open(filename, "w") as f:
|
| 660 |
+
json.dump({"settings": settings, "chat_history": chat}, f, indent=4)
|
| 661 |
+
return os.path.abspath(filename)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def get_username() -> str:
|
| 665 |
+
"""Returns the username of the current user."""
|
| 666 |
+
if platform.system() == "Windows":
|
| 667 |
+
return os.getlogin()
|
| 668 |
+
else:
|
| 669 |
+
return pwd.getpwuid(os.getuid()).pw_name
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
if __name__ == "__main__":
|
| 673 |
+
Chat(model_id="meta-llama/Llama-3.2-3b-Instruct")
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/download.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Annotated
|
| 15 |
+
|
| 16 |
+
import typer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def download(
|
| 20 |
+
model_id: Annotated[str, typer.Argument(help="The model ID to download")],
|
| 21 |
+
cache_dir: Annotated[str | None, typer.Option(help="Directory where to save files.")] = None,
|
| 22 |
+
force_download: Annotated[
|
| 23 |
+
bool, typer.Option(help="If set, the files will be downloaded even if they are already cached locally.")
|
| 24 |
+
] = False,
|
| 25 |
+
trust_remote_code: Annotated[
|
| 26 |
+
bool,
|
| 27 |
+
typer.Option(
|
| 28 |
+
help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine"
|
| 29 |
+
),
|
| 30 |
+
] = False,
|
| 31 |
+
):
|
| 32 |
+
"""Download a model and its tokenizer from the Hub."""
|
| 33 |
+
from ..models.auto import AutoModel, AutoTokenizer
|
| 34 |
+
|
| 35 |
+
AutoModel.from_pretrained(
|
| 36 |
+
model_id, cache_dir=cache_dir, force_download=force_download, trust_remote_code=trust_remote_code
|
| 37 |
+
)
|
| 38 |
+
AutoTokenizer.from_pretrained(
|
| 39 |
+
model_id, cache_dir=cache_dir, force_download=force_download, trust_remote_code=trust_remote_code
|
| 40 |
+
)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/serve.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
CLI entry point for `transformers serve`.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
import enum
|
| 20 |
+
import json
|
| 21 |
+
import threading
|
| 22 |
+
from typing import Annotated
|
| 23 |
+
|
| 24 |
+
import typer
|
| 25 |
+
|
| 26 |
+
from transformers.utils import logging
|
| 27 |
+
from transformers.utils.import_utils import is_serve_available
|
| 28 |
+
|
| 29 |
+
from .serving.utils import set_torch_seed
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ReasoningMode(str, enum.Enum):
|
| 36 |
+
ON = "on"
|
| 37 |
+
OFF = "off"
|
| 38 |
+
AUTO = "auto"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Serve:
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None,
|
| 45 |
+
# Model options
|
| 46 |
+
continuous_batching: Annotated[
|
| 47 |
+
bool,
|
| 48 |
+
typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."),
|
| 49 |
+
] = False,
|
| 50 |
+
attn_implementation: Annotated[
|
| 51 |
+
str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
|
| 52 |
+
] = None,
|
| 53 |
+
compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
|
| 54 |
+
quantization: Annotated[
|
| 55 |
+
str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
|
| 56 |
+
] = None,
|
| 57 |
+
reasoning: Annotated[
|
| 58 |
+
ReasoningMode,
|
| 59 |
+
typer.Option(
|
| 60 |
+
help=(
|
| 61 |
+
"Reasoning mode. 'auto' uses the chat template default. Only applies to models that "
|
| 62 |
+
"support reasoning via their chat template (e.g. Qwen3, Gemma 4) — for other models "
|
| 63 |
+
"this flag has no effect."
|
| 64 |
+
)
|
| 65 |
+
),
|
| 66 |
+
] = ReasoningMode.AUTO,
|
| 67 |
+
chat_template_kwargs: Annotated[
|
| 68 |
+
str | None,
|
| 69 |
+
typer.Option(
|
| 70 |
+
help=(
|
| 71 |
+
"Default JSON kwargs forwarded to apply_chat_template "
|
| 72 |
+
"(e.g. '{\"enable_thinking\": true}'); per-request chat_template_kwargs override these."
|
| 73 |
+
)
|
| 74 |
+
),
|
| 75 |
+
] = None,
|
| 76 |
+
device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
|
| 77 |
+
dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
|
| 78 |
+
trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
|
| 79 |
+
model_timeout: Annotated[
|
| 80 |
+
int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
|
| 81 |
+
] = 300,
|
| 82 |
+
# Continuous batching tuning
|
| 83 |
+
cb_block_size: Annotated[
|
| 84 |
+
int | None, typer.Option(help="KV cache block size in tokens for continuous batching.")
|
| 85 |
+
] = None,
|
| 86 |
+
cb_num_blocks: Annotated[
|
| 87 |
+
int | None, typer.Option(help="Number of KV cache blocks for continuous batching.")
|
| 88 |
+
] = None,
|
| 89 |
+
cb_max_batch_tokens: Annotated[
|
| 90 |
+
int | None, typer.Option(help="Maximum tokens per batch for continuous batching.")
|
| 91 |
+
] = None,
|
| 92 |
+
cb_max_memory_percent: Annotated[
|
| 93 |
+
float | None, typer.Option(help="Max GPU memory fraction for KV cache (0.0-1.0).")
|
| 94 |
+
] = None,
|
| 95 |
+
cb_use_cuda_graph: Annotated[
|
| 96 |
+
bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.")
|
| 97 |
+
] = None,
|
| 98 |
+
# Server options
|
| 99 |
+
host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost",
|
| 100 |
+
port: Annotated[int, typer.Option(help="Server listen port.")] = 8000,
|
| 101 |
+
enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False,
|
| 102 |
+
log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning",
|
| 103 |
+
default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None,
|
| 104 |
+
non_blocking: Annotated[
|
| 105 |
+
bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.")
|
| 106 |
+
] = False,
|
| 107 |
+
) -> None:
|
| 108 |
+
if not is_serve_available():
|
| 109 |
+
raise ImportError("Missing dependencies for serving. Install with `pip install transformers[serving]`")
|
| 110 |
+
|
| 111 |
+
import uvicorn
|
| 112 |
+
|
| 113 |
+
from .serving.chat_completion import ChatCompletionHandler
|
| 114 |
+
from .serving.completion import CompletionHandler
|
| 115 |
+
from .serving.model_manager import ModelManager
|
| 116 |
+
from .serving.response import ResponseHandler
|
| 117 |
+
from .serving.server import build_server
|
| 118 |
+
from .serving.transcription import TranscriptionHandler
|
| 119 |
+
from .serving.utils import GenerationState
|
| 120 |
+
|
| 121 |
+
# Seed
|
| 122 |
+
if default_seed is not None:
|
| 123 |
+
set_torch_seed(default_seed)
|
| 124 |
+
|
| 125 |
+
# Logging
|
| 126 |
+
transformers_logger = logging.get_logger("transformers")
|
| 127 |
+
transformers_logger.setLevel(logging.log_levels[log_level.lower()])
|
| 128 |
+
|
| 129 |
+
self._model_manager = ModelManager(
|
| 130 |
+
device=device,
|
| 131 |
+
dtype=dtype,
|
| 132 |
+
trust_remote_code=trust_remote_code,
|
| 133 |
+
attn_implementation=attn_implementation,
|
| 134 |
+
quantization=quantization,
|
| 135 |
+
model_timeout=model_timeout,
|
| 136 |
+
force_model=force_model,
|
| 137 |
+
)
|
| 138 |
+
from transformers import ContinuousBatchingConfig
|
| 139 |
+
|
| 140 |
+
cb_kwargs = {
|
| 141 |
+
k: v
|
| 142 |
+
for k, v in {
|
| 143 |
+
"block_size": cb_block_size,
|
| 144 |
+
"num_blocks": cb_num_blocks,
|
| 145 |
+
"max_batch_tokens": cb_max_batch_tokens,
|
| 146 |
+
"max_memory_percent": cb_max_memory_percent,
|
| 147 |
+
"use_cuda_graph": cb_use_cuda_graph,
|
| 148 |
+
}.items()
|
| 149 |
+
if v is not None
|
| 150 |
+
}
|
| 151 |
+
cb_config = ContinuousBatchingConfig(**cb_kwargs) if cb_kwargs else None
|
| 152 |
+
self._generation_state = GenerationState(
|
| 153 |
+
continuous_batching=continuous_batching,
|
| 154 |
+
compile=compile,
|
| 155 |
+
cb_config=cb_config,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
if chat_template_kwargs:
|
| 159 |
+
chat_template_kwargs = json.loads(chat_template_kwargs)
|
| 160 |
+
if not isinstance(chat_template_kwargs, dict):
|
| 161 |
+
raise typer.BadParameter("--chat-template-kwargs must be a JSON object")
|
| 162 |
+
else:
|
| 163 |
+
chat_template_kwargs = {}
|
| 164 |
+
|
| 165 |
+
if reasoning == ReasoningMode.ON:
|
| 166 |
+
chat_template_kwargs["enable_thinking"] = True
|
| 167 |
+
elif reasoning == ReasoningMode.OFF:
|
| 168 |
+
chat_template_kwargs["enable_thinking"] = False
|
| 169 |
+
|
| 170 |
+
self._chat_handler = ChatCompletionHandler(
|
| 171 |
+
model_manager=self._model_manager,
|
| 172 |
+
generation_state=self._generation_state,
|
| 173 |
+
chat_template_kwargs=chat_template_kwargs,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
self._completion_handler = CompletionHandler(
|
| 177 |
+
model_manager=self._model_manager,
|
| 178 |
+
generation_state=self._generation_state,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
self._response_handler = ResponseHandler(
|
| 182 |
+
model_manager=self._model_manager,
|
| 183 |
+
generation_state=self._generation_state,
|
| 184 |
+
chat_template_kwargs=chat_template_kwargs,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state)
|
| 188 |
+
|
| 189 |
+
app = build_server(
|
| 190 |
+
self._model_manager,
|
| 191 |
+
self._chat_handler,
|
| 192 |
+
completion_handler=self._completion_handler,
|
| 193 |
+
response_handler=self._response_handler,
|
| 194 |
+
transcription_handler=self._transcription_handler,
|
| 195 |
+
generation_state=self._generation_state,
|
| 196 |
+
enable_cors=enable_cors,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
| 200 |
+
self.server = uvicorn.Server(config)
|
| 201 |
+
|
| 202 |
+
if non_blocking:
|
| 203 |
+
self.start_server()
|
| 204 |
+
else:
|
| 205 |
+
self.server.run()
|
| 206 |
+
|
| 207 |
+
def start_server(self):
|
| 208 |
+
def _run():
|
| 209 |
+
loop = asyncio.new_event_loop()
|
| 210 |
+
asyncio.set_event_loop(loop)
|
| 211 |
+
loop.run_until_complete(self.server.serve())
|
| 212 |
+
|
| 213 |
+
self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False)
|
| 214 |
+
self._thread.start()
|
| 215 |
+
|
| 216 |
+
def reset_loaded_models(self):
|
| 217 |
+
"""Clear all loaded models from memory."""
|
| 218 |
+
self._model_manager.shutdown()
|
| 219 |
+
|
| 220 |
+
def kill_server(self):
|
| 221 |
+
self._generation_state.shutdown()
|
| 222 |
+
self._model_manager.shutdown()
|
| 223 |
+
if not self._thread or not self._thread.is_alive():
|
| 224 |
+
return
|
| 225 |
+
self.server.should_exit = True
|
| 226 |
+
self._thread.join(timeout=2)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
Serve.__doc__ = """
|
| 230 |
+
Run a FastAPI server to serve models on-demand with an OpenAI compatible API.
|
| 231 |
+
Models will be loaded and unloaded automatically based on usage and a timeout.
|
| 232 |
+
|
| 233 |
+
\b
|
| 234 |
+
Endpoints:
|
| 235 |
+
POST /v1/chat/completions — Chat completions (streaming + non-streaming).
|
| 236 |
+
POST /v1/completions — Legacy text completions from a prompt.
|
| 237 |
+
GET /v1/models — Lists available models.
|
| 238 |
+
GET /health — Health check.
|
| 239 |
+
|
| 240 |
+
Requires FastAPI and Uvicorn: pip install transformers[serving]
|
| 241 |
+
"""
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/system.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Contains commands to print information about the environment and version.
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
transformers env
|
| 18 |
+
transformers version
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import contextlib
|
| 22 |
+
import io
|
| 23 |
+
import os
|
| 24 |
+
import platform
|
| 25 |
+
from typing import Annotated
|
| 26 |
+
|
| 27 |
+
import huggingface_hub
|
| 28 |
+
import typer
|
| 29 |
+
|
| 30 |
+
from .. import __version__
|
| 31 |
+
from ..integrations.deepspeed import is_deepspeed_available
|
| 32 |
+
from ..utils import (
|
| 33 |
+
is_accelerate_available,
|
| 34 |
+
is_torch_available,
|
| 35 |
+
is_torch_hpu_available,
|
| 36 |
+
is_torch_npu_available,
|
| 37 |
+
is_torch_xpu_available,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def env(
|
| 42 |
+
accelerate_config_file: Annotated[
|
| 43 |
+
str | None,
|
| 44 |
+
typer.Argument(help="The accelerate config file to use for the default values in the launching script."),
|
| 45 |
+
] = None,
|
| 46 |
+
) -> None:
|
| 47 |
+
"""Print information about the environment."""
|
| 48 |
+
import safetensors
|
| 49 |
+
|
| 50 |
+
# TODO: remove hasattr guard once safetensors >= 0.8.0 is released (adds __version__)
|
| 51 |
+
safetensors_version = safetensors.__version__ if hasattr(safetensors, "__version__") else "unknown"
|
| 52 |
+
|
| 53 |
+
accelerate_version = "not installed"
|
| 54 |
+
accelerate_config = accelerate_config_str = "not found"
|
| 55 |
+
|
| 56 |
+
if is_accelerate_available():
|
| 57 |
+
import accelerate
|
| 58 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
| 59 |
+
|
| 60 |
+
accelerate_version = accelerate.__version__
|
| 61 |
+
# Get the default from the config file.
|
| 62 |
+
if accelerate_config_file is not None or os.path.isfile(default_config_file):
|
| 63 |
+
accelerate_config = load_config_from_file(accelerate_config_file).to_dict()
|
| 64 |
+
|
| 65 |
+
accelerate_config_str = (
|
| 66 |
+
"\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
|
| 67 |
+
if isinstance(accelerate_config, dict)
|
| 68 |
+
else f"\t{accelerate_config}"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
pt_version = "not installed"
|
| 72 |
+
pt_cuda_available = "NA"
|
| 73 |
+
pt_accelerator = "NA"
|
| 74 |
+
if is_torch_available():
|
| 75 |
+
import torch
|
| 76 |
+
|
| 77 |
+
pt_version = torch.__version__
|
| 78 |
+
pt_cuda_available = torch.cuda.is_available()
|
| 79 |
+
pt_xpu_available = is_torch_xpu_available()
|
| 80 |
+
pt_npu_available = is_torch_npu_available()
|
| 81 |
+
pt_hpu_available = is_torch_hpu_available()
|
| 82 |
+
|
| 83 |
+
if pt_cuda_available:
|
| 84 |
+
pt_accelerator = "CUDA"
|
| 85 |
+
elif pt_xpu_available:
|
| 86 |
+
pt_accelerator = "XPU"
|
| 87 |
+
elif pt_npu_available:
|
| 88 |
+
pt_accelerator = "NPU"
|
| 89 |
+
elif pt_hpu_available:
|
| 90 |
+
pt_accelerator = "HPU"
|
| 91 |
+
|
| 92 |
+
deepspeed_version = "not installed"
|
| 93 |
+
if is_deepspeed_available():
|
| 94 |
+
# Redirect command line output to silence deepspeed import output.
|
| 95 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 96 |
+
import deepspeed
|
| 97 |
+
deepspeed_version = deepspeed.__version__
|
| 98 |
+
|
| 99 |
+
info = {
|
| 100 |
+
"`transformers` version": __version__,
|
| 101 |
+
"Platform": platform.platform(),
|
| 102 |
+
"Python version": platform.python_version(),
|
| 103 |
+
"Huggingface_hub version": huggingface_hub.__version__,
|
| 104 |
+
"Safetensors version": f"{safetensors_version}",
|
| 105 |
+
"Accelerate version": f"{accelerate_version}",
|
| 106 |
+
"Accelerate config": f"{accelerate_config_str}",
|
| 107 |
+
"DeepSpeed version": f"{deepspeed_version}",
|
| 108 |
+
"PyTorch version (accelerator?)": f"{pt_version} ({pt_accelerator})",
|
| 109 |
+
"Using distributed or parallel set-up in script?": "<fill in>",
|
| 110 |
+
}
|
| 111 |
+
if is_torch_available():
|
| 112 |
+
if pt_cuda_available:
|
| 113 |
+
info["Using GPU in script?"] = "<fill in>"
|
| 114 |
+
info["GPU type"] = torch.cuda.get_device_name()
|
| 115 |
+
elif pt_xpu_available:
|
| 116 |
+
info["Using XPU in script?"] = "<fill in>"
|
| 117 |
+
info["XPU type"] = torch.xpu.get_device_name()
|
| 118 |
+
elif pt_hpu_available and hasattr(torch, "hpu"):
|
| 119 |
+
info["Using HPU in script?"] = "<fill in>"
|
| 120 |
+
info["HPU type"] = torch.hpu.get_device_name()
|
| 121 |
+
elif pt_npu_available and hasattr(torch, "npu"):
|
| 122 |
+
info["Using NPU in script?"] = "<fill in>"
|
| 123 |
+
info["NPU type"] = torch.npu.get_device_name()
|
| 124 |
+
if hasattr(torch.version, "cann"):
|
| 125 |
+
info["CANN version"] = torch.version.cann
|
| 126 |
+
|
| 127 |
+
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
| 128 |
+
print(_format_dict(info))
|
| 129 |
+
|
| 130 |
+
return info
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def version() -> None:
|
| 134 |
+
"""Print CLI version."""
|
| 135 |
+
print(__version__)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _format_dict(d: dict) -> str:
|
| 139 |
+
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/transformers.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Transformers CLI."""
|
| 15 |
+
|
| 16 |
+
from huggingface_hub import check_cli_update, typer_factory
|
| 17 |
+
|
| 18 |
+
from transformers.cli.add_new_model_like import add_new_model_like
|
| 19 |
+
from transformers.cli.chat import Chat
|
| 20 |
+
from transformers.cli.download import download
|
| 21 |
+
from transformers.cli.serve import Serve
|
| 22 |
+
from transformers.cli.system import env, version
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
app = typer_factory(help="Transformers CLI")
|
| 26 |
+
|
| 27 |
+
app.command()(add_new_model_like)
|
| 28 |
+
app.command(name="chat")(Chat)
|
| 29 |
+
app.command()(download)
|
| 30 |
+
app.command()(env)
|
| 31 |
+
app.command(name="serve")(Serve)
|
| 32 |
+
app.command()(version)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main():
|
| 36 |
+
check_cli_update("transformers")
|
| 37 |
+
app()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import TYPE_CHECKING
|
| 16 |
+
|
| 17 |
+
from ..utils import _LazyModule
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_import_structure = {
|
| 21 |
+
"configuration_utils": ["DistributedConfig"],
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from .configuration_utils import (
|
| 27 |
+
DistributedConfig,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
else:
|
| 31 |
+
import sys
|
| 32 |
+
|
| 33 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/configuration_utils.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import copy
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class DistributedConfig:
|
| 24 |
+
"""
|
| 25 |
+
Base class for distributed configs
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
enable_expert_parallel: bool = False
|
| 29 |
+
# TODO: add tp_plan, pp_plan, device_mesh etc..
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
def from_dict(cls, config_dict, **kwargs):
|
| 33 |
+
"""
|
| 34 |
+
Constructs a DistributedConfig instance from a dictionary of parameters.
|
| 35 |
+
Args:
|
| 36 |
+
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
| 37 |
+
**kwargs: Additional keyword arguments to override dictionary values.
|
| 38 |
+
Returns:
|
| 39 |
+
DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
|
| 40 |
+
"""
|
| 41 |
+
config = cls(**config_dict)
|
| 42 |
+
to_remove = []
|
| 43 |
+
for key, value in kwargs.items():
|
| 44 |
+
if hasattr(config, key):
|
| 45 |
+
setattr(config, key, value)
|
| 46 |
+
to_remove.append(key)
|
| 47 |
+
for key in to_remove:
|
| 48 |
+
kwargs.pop(key, None)
|
| 49 |
+
return config
|
| 50 |
+
|
| 51 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
|
| 52 |
+
def to_json_file(self, json_file_path: str | os.PathLike):
|
| 53 |
+
"""
|
| 54 |
+
Save this instance to a JSON file.
|
| 55 |
+
Args:
|
| 56 |
+
json_file_path (`str` or `os.PathLike`):
|
| 57 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
| 58 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
If set to `True`, only the difference between the config instance and the default
|
| 60 |
+
`QuantizationConfig()` is serialized to JSON file.
|
| 61 |
+
"""
|
| 62 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 63 |
+
config_dict = self.to_dict()
|
| 64 |
+
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 65 |
+
|
| 66 |
+
writer.write(json_string)
|
| 67 |
+
|
| 68 |
+
def to_dict(self) -> dict[str, Any]:
|
| 69 |
+
"""
|
| 70 |
+
Serializes this instance to a Python dictionary. Returns:
|
| 71 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
| 72 |
+
"""
|
| 73 |
+
return copy.deepcopy(self.__dict__)
|
| 74 |
+
|
| 75 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
|
| 76 |
+
def __iter__(self):
|
| 77 |
+
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
|
| 78 |
+
yield from copy.deepcopy(self.__dict__).items()
|
| 79 |
+
|
| 80 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
|
| 81 |
+
def __repr__(self):
|
| 82 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
| 83 |
+
|
| 84 |
+
def to_json_string(self):
|
| 85 |
+
"""
|
| 86 |
+
Serializes this instance to a JSON formatted string.
|
| 87 |
+
Returns:
|
| 88 |
+
str: JSON formatted string representing the configuration instance.
|
| 89 |
+
"""
|
| 90 |
+
return json.dumps(self.__dict__, indent=2) + "\n"
|
| 91 |
+
|
| 92 |
+
def update(self, **kwargs):
|
| 93 |
+
"""
|
| 94 |
+
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
| 95 |
+
returning all the unused kwargs.
|
| 96 |
+
Args:
|
| 97 |
+
kwargs (`Dict[str, Any]`):
|
| 98 |
+
Dictionary of attributes to tentatively update this class.
|
| 99 |
+
Returns:
|
| 100 |
+
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
|
| 101 |
+
"""
|
| 102 |
+
to_remove = []
|
| 103 |
+
for key, value in kwargs.items():
|
| 104 |
+
if hasattr(self, key):
|
| 105 |
+
setattr(self, key, value)
|
| 106 |
+
to_remove.append(key)
|
| 107 |
+
|
| 108 |
+
# Remove all the attributes that were updated, without modifying the input dict
|
| 109 |
+
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
| 110 |
+
return unused_kwargs
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/hyperparameter_search.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .integrations import (
|
| 16 |
+
is_optuna_available,
|
| 17 |
+
is_ray_tune_available,
|
| 18 |
+
is_wandb_available,
|
| 19 |
+
run_hp_search_optuna,
|
| 20 |
+
run_hp_search_ray,
|
| 21 |
+
run_hp_search_wandb,
|
| 22 |
+
)
|
| 23 |
+
from .trainer_utils import (
|
| 24 |
+
HPSearchBackend,
|
| 25 |
+
default_hp_space_optuna,
|
| 26 |
+
default_hp_space_ray,
|
| 27 |
+
default_hp_space_wandb,
|
| 28 |
+
)
|
| 29 |
+
from .utils import logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class HyperParamSearchBackendBase:
|
| 36 |
+
name: str
|
| 37 |
+
pip_package: str | None = None
|
| 38 |
+
|
| 39 |
+
@staticmethod
|
| 40 |
+
def is_available():
|
| 41 |
+
raise NotImplementedError
|
| 42 |
+
|
| 43 |
+
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
| 44 |
+
raise NotImplementedError
|
| 45 |
+
|
| 46 |
+
def default_hp_space(self, trial):
|
| 47 |
+
raise NotImplementedError
|
| 48 |
+
|
| 49 |
+
def ensure_available(self):
|
| 50 |
+
if not self.is_available():
|
| 51 |
+
raise RuntimeError(
|
| 52 |
+
f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
@classmethod
|
| 56 |
+
def pip_install(cls):
|
| 57 |
+
return f"`pip install {cls.pip_package or cls.name}`"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class OptunaBackend(HyperParamSearchBackendBase):
|
| 61 |
+
name = "optuna"
|
| 62 |
+
|
| 63 |
+
@staticmethod
|
| 64 |
+
def is_available():
|
| 65 |
+
return is_optuna_available()
|
| 66 |
+
|
| 67 |
+
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
| 68 |
+
return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
|
| 69 |
+
|
| 70 |
+
def default_hp_space(self, trial):
|
| 71 |
+
return default_hp_space_optuna(trial)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RayTuneBackend(HyperParamSearchBackendBase):
|
| 75 |
+
name = "ray"
|
| 76 |
+
pip_package = "'ray[tune]'"
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def is_available():
|
| 80 |
+
return is_ray_tune_available()
|
| 81 |
+
|
| 82 |
+
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
| 83 |
+
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
|
| 84 |
+
|
| 85 |
+
def default_hp_space(self, trial):
|
| 86 |
+
return default_hp_space_ray(trial)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class WandbBackend(HyperParamSearchBackendBase):
|
| 90 |
+
name = "wandb"
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def is_available():
|
| 94 |
+
return is_wandb_available()
|
| 95 |
+
|
| 96 |
+
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
| 97 |
+
return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
|
| 98 |
+
|
| 99 |
+
def default_hp_space(self, trial):
|
| 100 |
+
return default_hp_space_wandb(trial)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
|
| 104 |
+
HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, WandbBackend]
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def default_hp_search_backend() -> str:
|
| 109 |
+
available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
|
| 110 |
+
if len(available_backends) > 0:
|
| 111 |
+
name = available_backends[0].name
|
| 112 |
+
if len(available_backends) > 1:
|
| 113 |
+
logger.info(
|
| 114 |
+
f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
|
| 115 |
+
)
|
| 116 |
+
return name
|
| 117 |
+
raise RuntimeError(
|
| 118 |
+
"No hyperparameter search backend available.\n"
|
| 119 |
+
+ "\n".join(
|
| 120 |
+
f" - To install {backend.name} run {backend.pip_install()}"
|
| 121 |
+
for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
|
| 122 |
+
)
|
| 123 |
+
)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_transforms.py
ADDED
|
@@ -0,0 +1,1073 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from collections import defaultdict
|
| 16 |
+
from collections.abc import Collection, Iterable
|
| 17 |
+
from math import ceil
|
| 18 |
+
from typing import Optional, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from .image_utils import (
|
| 23 |
+
ChannelDimension,
|
| 24 |
+
ImageInput,
|
| 25 |
+
get_channel_dimension_axis,
|
| 26 |
+
get_image_size,
|
| 27 |
+
infer_channel_dimension_format,
|
| 28 |
+
)
|
| 29 |
+
from .utils import ExplicitEnum, TensorType, is_torch_tensor
|
| 30 |
+
from .utils.import_utils import (
|
| 31 |
+
is_torch_available,
|
| 32 |
+
is_vision_available,
|
| 33 |
+
requires_backends,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if is_vision_available():
|
| 38 |
+
import PIL
|
| 39 |
+
|
| 40 |
+
from .image_utils import PILImageResampling
|
| 41 |
+
|
| 42 |
+
if is_torch_available():
|
| 43 |
+
import torch
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def to_channel_dimension_format(
|
| 47 |
+
image: np.ndarray,
|
| 48 |
+
channel_dim: ChannelDimension | str,
|
| 49 |
+
input_channel_dim: ChannelDimension | str | None = None,
|
| 50 |
+
) -> np.ndarray:
|
| 51 |
+
"""
|
| 52 |
+
Converts `image` to the channel dimension format specified by `channel_dim`. The input
|
| 53 |
+
can have arbitrary number of leading dimensions. Only last three dimension will be permuted
|
| 54 |
+
to format the `image`.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
image (`numpy.ndarray`):
|
| 58 |
+
The image to have its channel dimension set.
|
| 59 |
+
channel_dim (`ChannelDimension`):
|
| 60 |
+
The channel dimension format to use.
|
| 61 |
+
input_channel_dim (`ChannelDimension`, *optional*):
|
| 62 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
|
| 66 |
+
"""
|
| 67 |
+
if not isinstance(image, np.ndarray):
|
| 68 |
+
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
| 69 |
+
|
| 70 |
+
if input_channel_dim is None:
|
| 71 |
+
input_channel_dim = infer_channel_dimension_format(image)
|
| 72 |
+
|
| 73 |
+
target_channel_dim = ChannelDimension(channel_dim)
|
| 74 |
+
if input_channel_dim == target_channel_dim:
|
| 75 |
+
return image
|
| 76 |
+
|
| 77 |
+
if target_channel_dim == ChannelDimension.FIRST:
|
| 78 |
+
axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2]
|
| 79 |
+
image = image.transpose(axes)
|
| 80 |
+
elif target_channel_dim == ChannelDimension.LAST:
|
| 81 |
+
axes = list(range(image.ndim - 3)) + [image.ndim - 2, image.ndim - 1, image.ndim - 3]
|
| 82 |
+
image = image.transpose(axes)
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f"Unsupported channel dimension format: {channel_dim}")
|
| 85 |
+
|
| 86 |
+
return image
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def rescale(
|
| 90 |
+
image: np.ndarray,
|
| 91 |
+
scale: float,
|
| 92 |
+
data_format: ChannelDimension | None = None,
|
| 93 |
+
dtype: np.dtype = np.float32,
|
| 94 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 95 |
+
) -> np.ndarray:
|
| 96 |
+
"""
|
| 97 |
+
Rescales `image` by `scale`.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
image (`np.ndarray`):
|
| 101 |
+
The image to rescale.
|
| 102 |
+
scale (`float`):
|
| 103 |
+
The scale to use for rescaling the image.
|
| 104 |
+
data_format (`ChannelDimension`, *optional*):
|
| 105 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 106 |
+
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
| 107 |
+
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
|
| 108 |
+
extractors.
|
| 109 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 110 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
`np.ndarray`: The rescaled image.
|
| 114 |
+
"""
|
| 115 |
+
if not isinstance(image, np.ndarray):
|
| 116 |
+
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
| 117 |
+
|
| 118 |
+
rescaled_image = image.astype(np.float64) * scale # Numpy type promotion has changed, so always upcast first
|
| 119 |
+
if data_format is not None:
|
| 120 |
+
rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
|
| 121 |
+
|
| 122 |
+
rescaled_image = rescaled_image.astype(dtype) # Finally downcast to the desired dtype at the end
|
| 123 |
+
|
| 124 |
+
return rescaled_image
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _rescale_for_pil_conversion(image):
|
| 128 |
+
"""
|
| 129 |
+
Detects whether or not the image needs to be rescaled before being converted to a PIL image.
|
| 130 |
+
|
| 131 |
+
The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be
|
| 132 |
+
rescaled.
|
| 133 |
+
"""
|
| 134 |
+
if image.dtype == np.uint8:
|
| 135 |
+
do_rescale = False
|
| 136 |
+
elif np.allclose(image, image.astype(int)):
|
| 137 |
+
if np.all(image >= 0) and np.all(image <= 255):
|
| 138 |
+
do_rescale = False
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError(
|
| 141 |
+
"The image to be converted to a PIL image contains values outside the range [0, 255], "
|
| 142 |
+
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
|
| 143 |
+
)
|
| 144 |
+
elif np.all(image >= 0) and np.all(image <= 1):
|
| 145 |
+
do_rescale = True
|
| 146 |
+
else:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
"The image to be converted to a PIL image contains values outside the range [0, 1], "
|
| 149 |
+
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
|
| 150 |
+
)
|
| 151 |
+
return do_rescale
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def to_pil_image(
|
| 155 |
+
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor"],
|
| 156 |
+
do_rescale: bool | None = None,
|
| 157 |
+
image_mode: str | None = None,
|
| 158 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 159 |
+
) -> "PIL.Image.Image":
|
| 160 |
+
"""
|
| 161 |
+
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
| 162 |
+
needed.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
|
| 166 |
+
The image to convert to the `PIL.Image` format.
|
| 167 |
+
do_rescale (`bool`, *optional*):
|
| 168 |
+
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
|
| 169 |
+
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
|
| 170 |
+
and `False` otherwise.
|
| 171 |
+
image_mode (`str`, *optional*):
|
| 172 |
+
The mode to use for the PIL image. If unset, will use the default mode for the input image type.
|
| 173 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 174 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
`PIL.Image.Image`: The converted image.
|
| 178 |
+
"""
|
| 179 |
+
requires_backends(to_pil_image, ["vision"])
|
| 180 |
+
|
| 181 |
+
if isinstance(image, PIL.Image.Image):
|
| 182 |
+
return image
|
| 183 |
+
|
| 184 |
+
# Convert all tensors to numpy arrays before converting to PIL image
|
| 185 |
+
if is_torch_tensor(image):
|
| 186 |
+
image = image.numpy()
|
| 187 |
+
elif not isinstance(image, np.ndarray):
|
| 188 |
+
raise ValueError(f"Input image type not supported: {type(image)}")
|
| 189 |
+
|
| 190 |
+
# If the channel has been moved to first dim, we put it back at the end.
|
| 191 |
+
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
|
| 192 |
+
|
| 193 |
+
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
|
| 194 |
+
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
| 195 |
+
|
| 196 |
+
# PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
|
| 197 |
+
do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale
|
| 198 |
+
|
| 199 |
+
if do_rescale:
|
| 200 |
+
image = rescale(image, 255)
|
| 201 |
+
|
| 202 |
+
image = image.astype(np.uint8)
|
| 203 |
+
return PIL.Image.fromarray(image, mode=image_mode)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
|
| 207 |
+
"""
|
| 208 |
+
Computes the output image size given the input image size and the desired output size.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
image_size (`tuple[int, int]`):
|
| 212 |
+
The input image size.
|
| 213 |
+
size (`int`):
|
| 214 |
+
The desired output size.
|
| 215 |
+
max_size (`int`, *optional*):
|
| 216 |
+
The maximum allowed output size.
|
| 217 |
+
"""
|
| 218 |
+
height, width = image_size
|
| 219 |
+
raw_size = None
|
| 220 |
+
if max_size is not None:
|
| 221 |
+
min_original_size = float(min((height, width)))
|
| 222 |
+
max_original_size = float(max((height, width)))
|
| 223 |
+
if max_original_size / min_original_size * size > max_size:
|
| 224 |
+
raw_size = max_size * min_original_size / max_original_size
|
| 225 |
+
size = int(round(raw_size))
|
| 226 |
+
|
| 227 |
+
if (height <= width and height == size) or (width <= height and width == size):
|
| 228 |
+
oh, ow = height, width
|
| 229 |
+
elif width < height:
|
| 230 |
+
ow = size
|
| 231 |
+
if max_size is not None and raw_size is not None:
|
| 232 |
+
oh = int(raw_size * height / width)
|
| 233 |
+
else:
|
| 234 |
+
oh = int(size * height / width)
|
| 235 |
+
else:
|
| 236 |
+
oh = size
|
| 237 |
+
if max_size is not None and raw_size is not None:
|
| 238 |
+
ow = int(raw_size * width / height)
|
| 239 |
+
else:
|
| 240 |
+
ow = int(size * width / height)
|
| 241 |
+
|
| 242 |
+
return (oh, ow)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
|
| 246 |
+
def get_resize_output_image_size(
|
| 247 |
+
input_image: np.ndarray,
|
| 248 |
+
size: int | tuple[int, int] | list[int] | tuple[int, ...],
|
| 249 |
+
default_to_square: bool = True,
|
| 250 |
+
max_size: int | None = None,
|
| 251 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 252 |
+
) -> tuple:
|
| 253 |
+
"""
|
| 254 |
+
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
| 255 |
+
size.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
input_image (`np.ndarray`):
|
| 259 |
+
The image to resize.
|
| 260 |
+
size (`int` or `tuple[int, int]` or list[int] or `tuple[int]`):
|
| 261 |
+
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
|
| 262 |
+
this.
|
| 263 |
+
|
| 264 |
+
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
|
| 265 |
+
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
|
| 266 |
+
number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
|
| 267 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 268 |
+
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
|
| 269 |
+
(`size`,`size`). If set to `False`, will replicate
|
| 270 |
+
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
|
| 271 |
+
with support for resizing only the smallest edge and providing an optional `max_size`.
|
| 272 |
+
max_size (`int`, *optional*):
|
| 273 |
+
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
|
| 274 |
+
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
|
| 275 |
+
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
|
| 276 |
+
than `size`. Only used if `default_to_square` is `False`.
|
| 277 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 278 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
`tuple`: The target (height, width) dimension of the output image after resizing.
|
| 282 |
+
"""
|
| 283 |
+
if isinstance(size, (tuple, list)):
|
| 284 |
+
if len(size) == 2:
|
| 285 |
+
return tuple(size)
|
| 286 |
+
elif len(size) == 1:
|
| 287 |
+
# Perform same logic as if size was an int
|
| 288 |
+
size = size[0]
|
| 289 |
+
else:
|
| 290 |
+
raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
|
| 291 |
+
|
| 292 |
+
if default_to_square:
|
| 293 |
+
return (size, size)
|
| 294 |
+
|
| 295 |
+
height, width = get_image_size(input_image, input_data_format)
|
| 296 |
+
short, long = (width, height) if width <= height else (height, width)
|
| 297 |
+
requested_new_short = size
|
| 298 |
+
|
| 299 |
+
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
| 300 |
+
|
| 301 |
+
if max_size is not None:
|
| 302 |
+
if max_size <= requested_new_short:
|
| 303 |
+
raise ValueError(
|
| 304 |
+
f"max_size = {max_size} must be strictly greater than the requested "
|
| 305 |
+
f"size for the smaller edge size = {size}"
|
| 306 |
+
)
|
| 307 |
+
if new_long > max_size:
|
| 308 |
+
new_short, new_long = int(max_size * new_short / new_long), max_size
|
| 309 |
+
|
| 310 |
+
return (new_long, new_short) if width <= height else (new_short, new_long)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def resize(
|
| 314 |
+
image: np.ndarray,
|
| 315 |
+
size: tuple[int, int],
|
| 316 |
+
resample: Optional["PILImageResampling"] = None,
|
| 317 |
+
reducing_gap: int | None = None,
|
| 318 |
+
data_format: ChannelDimension | None = None,
|
| 319 |
+
return_numpy: bool = True,
|
| 320 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 321 |
+
) -> np.ndarray:
|
| 322 |
+
"""
|
| 323 |
+
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
image (`np.ndarray`):
|
| 327 |
+
The image to resize.
|
| 328 |
+
size (`tuple[int, int]`):
|
| 329 |
+
The size to use for resizing the image.
|
| 330 |
+
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 331 |
+
The filter to user for resampling.
|
| 332 |
+
reducing_gap (`int`, *optional*):
|
| 333 |
+
Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
|
| 334 |
+
the fair resampling. See corresponding Pillow documentation for more details.
|
| 335 |
+
data_format (`ChannelDimension`, *optional*):
|
| 336 |
+
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
| 337 |
+
return_numpy (`bool`, *optional*, defaults to `True`):
|
| 338 |
+
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
|
| 339 |
+
returned.
|
| 340 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 341 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
`np.ndarray`: The resized image.
|
| 345 |
+
"""
|
| 346 |
+
requires_backends(resize, ["vision"])
|
| 347 |
+
|
| 348 |
+
resample = resample if resample is not None else PILImageResampling.BILINEAR
|
| 349 |
+
|
| 350 |
+
if not len(size) == 2:
|
| 351 |
+
raise ValueError("size must have 2 elements")
|
| 352 |
+
|
| 353 |
+
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
|
| 354 |
+
# The resized image from PIL will always have channels last, so find the input format first.
|
| 355 |
+
if input_data_format is None:
|
| 356 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 357 |
+
data_format = input_data_format if data_format is None else data_format
|
| 358 |
+
|
| 359 |
+
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
|
| 360 |
+
# the pillow library to resize the image and then convert back to numpy
|
| 361 |
+
do_rescale = False
|
| 362 |
+
if not isinstance(image, PIL.Image.Image):
|
| 363 |
+
do_rescale = _rescale_for_pil_conversion(image)
|
| 364 |
+
image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
|
| 365 |
+
height, width = size
|
| 366 |
+
# PIL images are in the format (width, height)
|
| 367 |
+
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
|
| 368 |
+
|
| 369 |
+
if return_numpy:
|
| 370 |
+
resized_image = np.array(resized_image)
|
| 371 |
+
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
|
| 372 |
+
# so we need to add it back if necessary.
|
| 373 |
+
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
|
| 374 |
+
# The image is always in channels last format after converting from a PIL image
|
| 375 |
+
resized_image = to_channel_dimension_format(
|
| 376 |
+
resized_image, data_format, input_channel_dim=ChannelDimension.LAST
|
| 377 |
+
)
|
| 378 |
+
# If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
|
| 379 |
+
# rescale it back to the original range.
|
| 380 |
+
resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
|
| 381 |
+
return resized_image
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def normalize(
|
| 385 |
+
image: np.ndarray,
|
| 386 |
+
mean: float | Collection[float],
|
| 387 |
+
std: float | Collection[float],
|
| 388 |
+
data_format: ChannelDimension | None = None,
|
| 389 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 390 |
+
) -> np.ndarray:
|
| 391 |
+
"""
|
| 392 |
+
Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
|
| 393 |
+
|
| 394 |
+
image = (image - mean) / std
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
image (`np.ndarray`):
|
| 398 |
+
The image to normalize.
|
| 399 |
+
mean (`float` or `Collection[float]`):
|
| 400 |
+
The mean to use for normalization.
|
| 401 |
+
std (`float` or `Collection[float]`):
|
| 402 |
+
The standard deviation to use for normalization.
|
| 403 |
+
data_format (`ChannelDimension`, *optional*):
|
| 404 |
+
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
| 405 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 406 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 407 |
+
"""
|
| 408 |
+
if not isinstance(image, np.ndarray):
|
| 409 |
+
raise TypeError("image must be a numpy array")
|
| 410 |
+
|
| 411 |
+
if input_data_format is None:
|
| 412 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 413 |
+
|
| 414 |
+
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
|
| 415 |
+
num_channels = image.shape[channel_axis]
|
| 416 |
+
|
| 417 |
+
# We cast to float32 to avoid errors that can occur when subtracting uint8 values.
|
| 418 |
+
# We preserve the original dtype if it is a float type to prevent upcasting float16.
|
| 419 |
+
if not np.issubdtype(image.dtype, np.floating):
|
| 420 |
+
image = image.astype(np.float32)
|
| 421 |
+
|
| 422 |
+
if isinstance(mean, Collection):
|
| 423 |
+
if len(mean) != num_channels:
|
| 424 |
+
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
|
| 425 |
+
else:
|
| 426 |
+
mean = [mean] * num_channels
|
| 427 |
+
mean = np.array(mean, dtype=image.dtype)
|
| 428 |
+
|
| 429 |
+
if isinstance(std, Collection):
|
| 430 |
+
if len(std) != num_channels:
|
| 431 |
+
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
|
| 432 |
+
else:
|
| 433 |
+
std = [std] * num_channels
|
| 434 |
+
std = np.array(std, dtype=image.dtype)
|
| 435 |
+
|
| 436 |
+
if input_data_format == ChannelDimension.LAST:
|
| 437 |
+
image = (image - mean) / std
|
| 438 |
+
else:
|
| 439 |
+
image = ((image.T - mean) / std).T
|
| 440 |
+
|
| 441 |
+
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
| 442 |
+
return image
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def center_crop(
|
| 446 |
+
image: np.ndarray,
|
| 447 |
+
size: tuple[int, int],
|
| 448 |
+
data_format: str | ChannelDimension | None = None,
|
| 449 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 450 |
+
) -> np.ndarray:
|
| 451 |
+
"""
|
| 452 |
+
Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
|
| 453 |
+
the size given, it will be padded (so the returned result will always be of size `size`).
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
image (`np.ndarray`):
|
| 457 |
+
The image to crop.
|
| 458 |
+
size (`tuple[int, int]`):
|
| 459 |
+
The target size for the cropped image.
|
| 460 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 461 |
+
The channel dimension format for the output image. Can be one of:
|
| 462 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 463 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 464 |
+
If unset, will use the inferred format of the input image.
|
| 465 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 466 |
+
The channel dimension format for the input image. Can be one of:
|
| 467 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 468 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 469 |
+
If unset, will use the inferred format of the input image.
|
| 470 |
+
Returns:
|
| 471 |
+
`np.ndarray`: The cropped image.
|
| 472 |
+
"""
|
| 473 |
+
requires_backends(center_crop, ["vision"])
|
| 474 |
+
|
| 475 |
+
if not isinstance(image, np.ndarray):
|
| 476 |
+
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
| 477 |
+
|
| 478 |
+
if not isinstance(size, Iterable) or len(size) != 2:
|
| 479 |
+
raise ValueError("size must have 2 elements representing the height and width of the output image")
|
| 480 |
+
|
| 481 |
+
if input_data_format is None:
|
| 482 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 483 |
+
output_data_format = data_format if data_format is not None else input_data_format
|
| 484 |
+
|
| 485 |
+
# We perform the crop in (C, H, W) format and then convert to the output format
|
| 486 |
+
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
|
| 487 |
+
|
| 488 |
+
orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
|
| 489 |
+
crop_height, crop_width = size
|
| 490 |
+
crop_height, crop_width = int(crop_height), int(crop_width)
|
| 491 |
+
|
| 492 |
+
# In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
|
| 493 |
+
top = (orig_height - crop_height) // 2
|
| 494 |
+
bottom = top + crop_height
|
| 495 |
+
# In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
|
| 496 |
+
left = (orig_width - crop_width) // 2
|
| 497 |
+
right = left + crop_width
|
| 498 |
+
|
| 499 |
+
# Check if cropped area is within image boundaries
|
| 500 |
+
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
|
| 501 |
+
image = image[..., top:bottom, left:right]
|
| 502 |
+
image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
|
| 503 |
+
return image
|
| 504 |
+
|
| 505 |
+
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
| 506 |
+
new_height = max(crop_height, orig_height)
|
| 507 |
+
new_width = max(crop_width, orig_width)
|
| 508 |
+
new_shape = image.shape[:-2] + (new_height, new_width)
|
| 509 |
+
new_image = np.zeros_like(image, shape=new_shape)
|
| 510 |
+
|
| 511 |
+
# If the image is too small, pad it with zeros
|
| 512 |
+
top_pad = ceil((new_height - orig_height) / 2)
|
| 513 |
+
bottom_pad = top_pad + orig_height
|
| 514 |
+
left_pad = ceil((new_width - orig_width) / 2)
|
| 515 |
+
right_pad = left_pad + orig_width
|
| 516 |
+
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
|
| 517 |
+
|
| 518 |
+
top += top_pad
|
| 519 |
+
bottom += top_pad
|
| 520 |
+
left += left_pad
|
| 521 |
+
right += left_pad
|
| 522 |
+
|
| 523 |
+
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
|
| 524 |
+
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
|
| 525 |
+
|
| 526 |
+
return new_image
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
|
| 530 |
+
center_x, center_y, width, height = bboxes_center.unbind(-1)
|
| 531 |
+
bbox_corners = torch.stack(
|
| 532 |
+
# top left x, top left y, bottom right x, bottom right y
|
| 533 |
+
[(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
|
| 534 |
+
dim=-1,
|
| 535 |
+
)
|
| 536 |
+
return bbox_corners
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
|
| 540 |
+
center_x, center_y, width, height = bboxes_center.T
|
| 541 |
+
bboxes_corners = np.stack(
|
| 542 |
+
# top left x, top left y, bottom right x, bottom right y
|
| 543 |
+
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
|
| 544 |
+
axis=-1,
|
| 545 |
+
)
|
| 546 |
+
return bboxes_corners
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
|
| 550 |
+
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
|
| 551 |
+
"""
|
| 552 |
+
Converts bounding boxes from center format to corners format.
|
| 553 |
+
|
| 554 |
+
center format: contains the coordinate for the center of the box and its width, height dimensions
|
| 555 |
+
(center_x, center_y, width, height)
|
| 556 |
+
corners format: contains the coordinates for the top-left and bottom-right corners of the box
|
| 557 |
+
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
|
| 558 |
+
"""
|
| 559 |
+
# Function is used during model forward pass, so we use torch if relevant, without converting to numpy
|
| 560 |
+
if is_torch_tensor(bboxes_center):
|
| 561 |
+
return _center_to_corners_format_torch(bboxes_center)
|
| 562 |
+
elif isinstance(bboxes_center, np.ndarray):
|
| 563 |
+
return _center_to_corners_format_numpy(bboxes_center)
|
| 564 |
+
|
| 565 |
+
raise ValueError(f"Unsupported input type {type(bboxes_center)}")
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
|
| 569 |
+
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
|
| 570 |
+
b = [
|
| 571 |
+
(top_left_x + bottom_right_x) / 2, # center x
|
| 572 |
+
(top_left_y + bottom_right_y) / 2, # center y
|
| 573 |
+
(bottom_right_x - top_left_x), # width
|
| 574 |
+
(bottom_right_y - top_left_y), # height
|
| 575 |
+
]
|
| 576 |
+
return torch.stack(b, dim=-1)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
|
| 580 |
+
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
|
| 581 |
+
bboxes_center = np.stack(
|
| 582 |
+
[
|
| 583 |
+
(top_left_x + bottom_right_x) / 2, # center x
|
| 584 |
+
(top_left_y + bottom_right_y) / 2, # center y
|
| 585 |
+
(bottom_right_x - top_left_x), # width
|
| 586 |
+
(bottom_right_y - top_left_y), # height
|
| 587 |
+
],
|
| 588 |
+
axis=-1,
|
| 589 |
+
)
|
| 590 |
+
return bboxes_center
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
|
| 594 |
+
"""
|
| 595 |
+
Converts bounding boxes from corners format to center format.
|
| 596 |
+
|
| 597 |
+
corners format: contains the coordinates for the top-left and bottom-right corners of the box
|
| 598 |
+
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
|
| 599 |
+
center format: contains the coordinate for the center of the box and its the width, height dimensions
|
| 600 |
+
(center_x, center_y, width, height)
|
| 601 |
+
"""
|
| 602 |
+
# Inverse function accepts different input types so implemented here too
|
| 603 |
+
if is_torch_tensor(bboxes_corners):
|
| 604 |
+
return _corners_to_center_format_torch(bboxes_corners)
|
| 605 |
+
elif isinstance(bboxes_corners, np.ndarray):
|
| 606 |
+
return _corners_to_center_format_numpy(bboxes_corners)
|
| 607 |
+
|
| 608 |
+
raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def safe_squeeze(
|
| 612 |
+
tensor: Union[np.ndarray, "torch.Tensor"], axis: int | None = None
|
| 613 |
+
) -> Union[np.ndarray, "torch.Tensor"]:
|
| 614 |
+
"""
|
| 615 |
+
Squeezes a tensor, but only if the axis specified has dim 1.
|
| 616 |
+
"""
|
| 617 |
+
if axis is None:
|
| 618 |
+
return tensor.squeeze()
|
| 619 |
+
|
| 620 |
+
try:
|
| 621 |
+
return tensor.squeeze(axis=axis)
|
| 622 |
+
except ValueError:
|
| 623 |
+
return tensor
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
|
| 627 |
+
# Copyright (c) 2018, Alexander Kirillov
|
| 628 |
+
# All rights reserved.
|
| 629 |
+
def rgb_to_id(color):
|
| 630 |
+
"""
|
| 631 |
+
Converts RGB color to unique ID.
|
| 632 |
+
"""
|
| 633 |
+
if isinstance(color, np.ndarray) and len(color.shape) == 3:
|
| 634 |
+
if color.dtype == np.uint8:
|
| 635 |
+
color = color.astype(np.int32)
|
| 636 |
+
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
|
| 637 |
+
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def id_to_rgb(id_map):
|
| 641 |
+
"""
|
| 642 |
+
Converts unique ID to RGB color.
|
| 643 |
+
"""
|
| 644 |
+
if isinstance(id_map, np.ndarray):
|
| 645 |
+
id_map_copy = id_map.copy()
|
| 646 |
+
rgb_shape = tuple(list(id_map.shape) + [3])
|
| 647 |
+
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
|
| 648 |
+
for i in range(3):
|
| 649 |
+
rgb_map[..., i] = id_map_copy % 256
|
| 650 |
+
id_map_copy //= 256
|
| 651 |
+
return rgb_map
|
| 652 |
+
color = []
|
| 653 |
+
for _ in range(3):
|
| 654 |
+
color.append(id_map % 256)
|
| 655 |
+
id_map //= 256
|
| 656 |
+
return color
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
class PaddingMode(ExplicitEnum):
|
| 660 |
+
"""
|
| 661 |
+
Enum class for the different padding modes to use when padding images.
|
| 662 |
+
"""
|
| 663 |
+
|
| 664 |
+
CONSTANT = "constant"
|
| 665 |
+
REFLECT = "reflect"
|
| 666 |
+
REPLICATE = "replicate"
|
| 667 |
+
SYMMETRIC = "symmetric"
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def pad(
|
| 671 |
+
image: np.ndarray,
|
| 672 |
+
padding: int | tuple[int, int] | Iterable[tuple[int, int]],
|
| 673 |
+
mode: PaddingMode = PaddingMode.CONSTANT,
|
| 674 |
+
constant_values: float | Iterable[float] = 0.0,
|
| 675 |
+
data_format: str | ChannelDimension | None = None,
|
| 676 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 677 |
+
) -> np.ndarray:
|
| 678 |
+
"""
|
| 679 |
+
Pads the `image` with the specified (height, width) `padding` and `mode`.
|
| 680 |
+
|
| 681 |
+
Args:
|
| 682 |
+
image (`np.ndarray`):
|
| 683 |
+
The image to pad.
|
| 684 |
+
padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
|
| 685 |
+
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
| 686 |
+
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
| 687 |
+
- `((before, after),)` yields same before and after pad for height and width.
|
| 688 |
+
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
| 689 |
+
mode (`PaddingMode`):
|
| 690 |
+
The padding mode to use. Can be one of:
|
| 691 |
+
- `"constant"`: pads with a constant value.
|
| 692 |
+
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
| 693 |
+
vector along each axis.
|
| 694 |
+
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
| 695 |
+
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
| 696 |
+
constant_values (`float` or `Iterable[float]`, *optional*):
|
| 697 |
+
The value to use for the padding if `mode` is `"constant"`.
|
| 698 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 699 |
+
The channel dimension format for the output image. Can be one of:
|
| 700 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 701 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 702 |
+
If unset, will use same as the input image.
|
| 703 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 704 |
+
The channel dimension format for the input image. Can be one of:
|
| 705 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 706 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 707 |
+
If unset, will use the inferred format of the input image.
|
| 708 |
+
|
| 709 |
+
Returns:
|
| 710 |
+
`np.ndarray`: The padded image.
|
| 711 |
+
|
| 712 |
+
"""
|
| 713 |
+
if input_data_format is None:
|
| 714 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 715 |
+
|
| 716 |
+
def _expand_for_data_format(values):
|
| 717 |
+
"""
|
| 718 |
+
Convert values to be in the format expected by np.pad based on the data format.
|
| 719 |
+
"""
|
| 720 |
+
if isinstance(values, (int, float)):
|
| 721 |
+
values = ((values, values), (values, values))
|
| 722 |
+
elif isinstance(values, tuple) and len(values) == 1:
|
| 723 |
+
values = ((values[0], values[0]), (values[0], values[0]))
|
| 724 |
+
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
|
| 725 |
+
values = (values, values)
|
| 726 |
+
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
|
| 727 |
+
pass
|
| 728 |
+
else:
|
| 729 |
+
raise ValueError(f"Unsupported format: {values}")
|
| 730 |
+
|
| 731 |
+
# add 0 for channel dimension
|
| 732 |
+
values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))
|
| 733 |
+
|
| 734 |
+
# Add additional padding if there's a batch dimension
|
| 735 |
+
values = ((0, 0), *values) if image.ndim == 4 else values
|
| 736 |
+
return values
|
| 737 |
+
|
| 738 |
+
padding = _expand_for_data_format(padding)
|
| 739 |
+
|
| 740 |
+
if mode == PaddingMode.CONSTANT:
|
| 741 |
+
constant_values = _expand_for_data_format(constant_values)
|
| 742 |
+
image = np.pad(image, padding, mode="constant", constant_values=constant_values)
|
| 743 |
+
elif mode == PaddingMode.REFLECT:
|
| 744 |
+
image = np.pad(image, padding, mode="reflect")
|
| 745 |
+
elif mode == PaddingMode.REPLICATE:
|
| 746 |
+
image = np.pad(image, padding, mode="edge")
|
| 747 |
+
elif mode == PaddingMode.SYMMETRIC:
|
| 748 |
+
image = np.pad(image, padding, mode="symmetric")
|
| 749 |
+
else:
|
| 750 |
+
raise ValueError(f"Invalid padding mode: {mode}")
|
| 751 |
+
|
| 752 |
+
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
| 753 |
+
return image
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
|
| 757 |
+
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
| 758 |
+
"""
|
| 759 |
+
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
| 760 |
+
as is.
|
| 761 |
+
Args:
|
| 762 |
+
image (Image):
|
| 763 |
+
The image to convert.
|
| 764 |
+
"""
|
| 765 |
+
requires_backends(convert_to_rgb, ["vision"])
|
| 766 |
+
|
| 767 |
+
if not isinstance(image, PIL.Image.Image):
|
| 768 |
+
return image
|
| 769 |
+
|
| 770 |
+
if image.mode == "RGB":
|
| 771 |
+
return image
|
| 772 |
+
|
| 773 |
+
image = image.convert("RGB")
|
| 774 |
+
return image
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def flip_channel_order(
|
| 778 |
+
image: np.ndarray,
|
| 779 |
+
data_format: ChannelDimension | None = None,
|
| 780 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 781 |
+
) -> np.ndarray:
|
| 782 |
+
"""
|
| 783 |
+
Flips the channel order of the image.
|
| 784 |
+
|
| 785 |
+
If the image is in RGB format, it will be converted to BGR and vice versa.
|
| 786 |
+
|
| 787 |
+
Args:
|
| 788 |
+
image (`np.ndarray`):
|
| 789 |
+
The image to flip.
|
| 790 |
+
data_format (`ChannelDimension`, *optional*):
|
| 791 |
+
The channel dimension format for the output image. Can be one of:
|
| 792 |
+
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 793 |
+
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 794 |
+
If unset, will use same as the input image.
|
| 795 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 796 |
+
The channel dimension format for the input image. Can be one of:
|
| 797 |
+
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 798 |
+
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 799 |
+
If unset, will use the inferred format of the input image.
|
| 800 |
+
"""
|
| 801 |
+
input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
|
| 802 |
+
|
| 803 |
+
if input_data_format == ChannelDimension.LAST:
|
| 804 |
+
image = image[..., ::-1]
|
| 805 |
+
elif input_data_format == ChannelDimension.FIRST:
|
| 806 |
+
image = image[::-1, ...]
|
| 807 |
+
else:
|
| 808 |
+
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
|
| 809 |
+
|
| 810 |
+
if data_format is not None:
|
| 811 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 812 |
+
return image
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
|
| 816 |
+
# Split image into number of required tiles (width x height)
|
| 817 |
+
batch_size, num_channels, height, width = images.size()
|
| 818 |
+
images = images.view(
|
| 819 |
+
batch_size,
|
| 820 |
+
num_channels,
|
| 821 |
+
num_tiles_height,
|
| 822 |
+
height // num_tiles_height,
|
| 823 |
+
num_tiles_width,
|
| 824 |
+
width // num_tiles_width,
|
| 825 |
+
)
|
| 826 |
+
# Permute dimensions to reorder the axes
|
| 827 |
+
image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
|
| 828 |
+
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
| 829 |
+
image = image.view(
|
| 830 |
+
batch_size,
|
| 831 |
+
num_tiles_width * num_tiles_height,
|
| 832 |
+
num_channels,
|
| 833 |
+
height // num_tiles_height,
|
| 834 |
+
width // num_tiles_width,
|
| 835 |
+
)
|
| 836 |
+
return image
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
def divide_to_patches(
|
| 840 |
+
image: Union[np.ndarray, "torch.Tensor"], patch_size: int | tuple[int, int]
|
| 841 |
+
) -> list[Union[np.ndarray, "torch.Tensor"]]:
|
| 842 |
+
"""
|
| 843 |
+
Divides an image into patches of a specified size.
|
| 844 |
+
|
| 845 |
+
Args:
|
| 846 |
+
image (`np.array | "torch.Tensor"`):
|
| 847 |
+
The input image.
|
| 848 |
+
patch_size (`int` or `tuple[int, int]`):
|
| 849 |
+
The size of each patch. If an int, patches are square. If a tuple,
|
| 850 |
+
it is interpreted as `(patch_height, patch_width)`.
|
| 851 |
+
Returns:
|
| 852 |
+
list: A list of `np.array | "torch.Tensor"` representing the patches.
|
| 853 |
+
"""
|
| 854 |
+
patch_h, patch_w = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
| 855 |
+
patches = []
|
| 856 |
+
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
| 857 |
+
for i in range(0, height, patch_h):
|
| 858 |
+
for j in range(0, width, patch_w):
|
| 859 |
+
patch = image[..., i : i + patch_h, j : j + patch_w]
|
| 860 |
+
patches.append(patch)
|
| 861 |
+
|
| 862 |
+
return patches
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
|
| 866 |
+
"""
|
| 867 |
+
Helper function to flatten a single level of nested image and batch structures and group by shape.
|
| 868 |
+
Args:
|
| 869 |
+
nested_images (list):
|
| 870 |
+
A list of images or a single tensor
|
| 871 |
+
paired_inputs (Any, *optional*):
|
| 872 |
+
Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
|
| 873 |
+
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
|
| 874 |
+
same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
|
| 875 |
+
they do not need to be tensors.
|
| 876 |
+
is_nested (bool, *optional*, defaults to False):
|
| 877 |
+
Whether the images are nested.
|
| 878 |
+
Returns:
|
| 879 |
+
tuple[dict, ...]:
|
| 880 |
+
- A dictionary with shape as key and list of images with that shape as value
|
| 881 |
+
- A dictionary with shape as key and list of paired values with that shape as value
|
| 882 |
+
- A dictionary mapping original indices to (shape, index) tuples
|
| 883 |
+
- A dictionary mapping original indices to (shape, index) tuples for each paired input
|
| 884 |
+
"""
|
| 885 |
+
grouped_images = defaultdict(list)
|
| 886 |
+
grouped_images_index = {}
|
| 887 |
+
paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
|
| 888 |
+
|
| 889 |
+
# Normalize inputs to consistent nested structure
|
| 890 |
+
normalized_images = [nested_images] if not is_nested else nested_images
|
| 891 |
+
normalized_paired = []
|
| 892 |
+
for paired_input in paired_inputs:
|
| 893 |
+
normalized_paired.append([paired_input] if not is_nested else paired_input)
|
| 894 |
+
|
| 895 |
+
# Process each image and group by shape
|
| 896 |
+
for i, (sublist, *paired_sublists) in enumerate(zip(normalized_images, *normalized_paired)):
|
| 897 |
+
for j, (image, *paired_values) in enumerate(zip(sublist, *paired_sublists)):
|
| 898 |
+
key = (i, j) if is_nested else j
|
| 899 |
+
shape = image.shape[1:]
|
| 900 |
+
|
| 901 |
+
# Add to grouped structures
|
| 902 |
+
grouped_images[shape].append(image)
|
| 903 |
+
for paired_index, paired_value in enumerate(paired_values):
|
| 904 |
+
paired_grouped_values[paired_index][shape].append(paired_value)
|
| 905 |
+
grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
|
| 906 |
+
|
| 907 |
+
# Store structure size for nested inputs to handle empty sublists during reconstruction
|
| 908 |
+
if is_nested:
|
| 909 |
+
grouped_images_index["_num_sublists"] = len(normalized_images)
|
| 910 |
+
|
| 911 |
+
return grouped_images, *paired_grouped_values, grouped_images_index
|
| 912 |
+
|
| 913 |
+
|
| 914 |
+
def _reconstruct_nested_structure(indices, processed_images):
|
| 915 |
+
"""Helper function to reconstruct a single level nested structure."""
|
| 916 |
+
# Get the number of sublists (handles empty sublists like in [[], [image]])
|
| 917 |
+
num_sublists = indices.pop("_num_sublists", None)
|
| 918 |
+
|
| 919 |
+
# Group indices by outer index
|
| 920 |
+
nested_indices = defaultdict(list)
|
| 921 |
+
for i, j in indices:
|
| 922 |
+
nested_indices[i].append(j)
|
| 923 |
+
|
| 924 |
+
# Determine the number of outer sublists
|
| 925 |
+
if num_sublists is not None:
|
| 926 |
+
max_outer_idx = num_sublists - 1
|
| 927 |
+
elif nested_indices:
|
| 928 |
+
max_outer_idx = max(nested_indices.keys())
|
| 929 |
+
else:
|
| 930 |
+
return []
|
| 931 |
+
|
| 932 |
+
# Create the result structure
|
| 933 |
+
result = []
|
| 934 |
+
for i in range(max_outer_idx + 1):
|
| 935 |
+
if i not in nested_indices:
|
| 936 |
+
result.append([])
|
| 937 |
+
else:
|
| 938 |
+
inner_max_idx = max(nested_indices[i])
|
| 939 |
+
inner_list = [None] * (inner_max_idx + 1)
|
| 940 |
+
for j in nested_indices[i]:
|
| 941 |
+
shape, idx = indices[(i, j)]
|
| 942 |
+
inner_list[j] = processed_images[shape][idx]
|
| 943 |
+
result.append(inner_list)
|
| 944 |
+
|
| 945 |
+
return result
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def _iterate_items(items, is_nested: bool):
|
| 949 |
+
"""
|
| 950 |
+
Helper function to iterate over items yielding (key, item) pairs.
|
| 951 |
+
|
| 952 |
+
For nested structures, yields ((row_index, col_index), item).
|
| 953 |
+
For flat structures, yields (index, item).
|
| 954 |
+
"""
|
| 955 |
+
if is_nested:
|
| 956 |
+
for i, row in enumerate(items):
|
| 957 |
+
for j, item in enumerate(row):
|
| 958 |
+
yield (i, j), item
|
| 959 |
+
else:
|
| 960 |
+
for i, item in enumerate(items):
|
| 961 |
+
yield i, item
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
def _get_device_from_images(images, is_nested: bool) -> "torch.device":
|
| 965 |
+
"""
|
| 966 |
+
Get the device from the first non-empty element in a (potentially nested) list of images.
|
| 967 |
+
|
| 968 |
+
Handles cases like `images = [[], [image]]` where the first sublist may be empty.
|
| 969 |
+
"""
|
| 970 |
+
if is_nested:
|
| 971 |
+
for row in images:
|
| 972 |
+
if isinstance(row, torch.Tensor):
|
| 973 |
+
return row.device
|
| 974 |
+
if isinstance(row, list) and len(row) > 0:
|
| 975 |
+
return row[0].device
|
| 976 |
+
return images[0].device
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
def group_images_by_shape(
|
| 980 |
+
images: Union[list["torch.Tensor"], "torch.Tensor"],
|
| 981 |
+
*paired_inputs,
|
| 982 |
+
disable_grouping: bool | None,
|
| 983 |
+
is_nested: bool = False,
|
| 984 |
+
) -> tuple[dict, ...]:
|
| 985 |
+
"""
|
| 986 |
+
Groups images by shape.
|
| 987 |
+
Returns a dictionary with the shape as key and a list of images with that shape as value,
|
| 988 |
+
and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
|
| 989 |
+
|
| 990 |
+
The function supports both flat lists of tensors and nested structures.
|
| 991 |
+
The input must be either all flat or all nested, not a mix of both.
|
| 992 |
+
|
| 993 |
+
Args:
|
| 994 |
+
images (Union[list["torch.Tensor"], "torch.Tensor"]):
|
| 995 |
+
A list of images or a single tensor
|
| 996 |
+
paired_inputs (Any, *optional*):
|
| 997 |
+
Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
|
| 998 |
+
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
|
| 999 |
+
same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
|
| 1000 |
+
they do not need to be tensors.
|
| 1001 |
+
disable_grouping (bool):
|
| 1002 |
+
Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise.
|
| 1003 |
+
This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
|
| 1004 |
+
is_nested (bool, *optional*, defaults to False):
|
| 1005 |
+
Whether the images are nested.
|
| 1006 |
+
|
| 1007 |
+
Returns:
|
| 1008 |
+
tuple[dict, ...]:
|
| 1009 |
+
- A dictionary with shape as key and list/batch of images with that shape as value
|
| 1010 |
+
- Zero or more dictionaries (one per argument in `*paired_inputs`) grouped consistently with `images`; these carry
|
| 1011 |
+
the corresponding per-item values and are not stacked
|
| 1012 |
+
- A dictionary mapping original indices to (shape, index) tuples
|
| 1013 |
+
"""
|
| 1014 |
+
# If disable grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
|
| 1015 |
+
if disable_grouping is None:
|
| 1016 |
+
device = _get_device_from_images(images, is_nested)
|
| 1017 |
+
disable_grouping = device == "cpu"
|
| 1018 |
+
|
| 1019 |
+
if disable_grouping:
|
| 1020 |
+
grouped_images_index = {key: (key, 0) for key, _ in _iterate_items(images, is_nested)}
|
| 1021 |
+
if is_nested:
|
| 1022 |
+
grouped_images_index["_num_sublists"] = len(images)
|
| 1023 |
+
|
| 1024 |
+
return (
|
| 1025 |
+
{key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
|
| 1026 |
+
*[
|
| 1027 |
+
{key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
|
| 1028 |
+
for paired_list in paired_inputs
|
| 1029 |
+
],
|
| 1030 |
+
grouped_images_index,
|
| 1031 |
+
)
|
| 1032 |
+
|
| 1033 |
+
# Handle single level nested structure
|
| 1034 |
+
grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
|
| 1035 |
+
images, *paired_inputs, is_nested=is_nested
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
# Stack images with the same shape
|
| 1039 |
+
grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
|
| 1040 |
+
|
| 1041 |
+
return grouped_images, *paired_grouped_values, grouped_images_index
|
| 1042 |
+
|
| 1043 |
+
|
| 1044 |
+
def reorder_images(
|
| 1045 |
+
processed_images: dict[tuple[int, int], "torch.Tensor"],
|
| 1046 |
+
grouped_images_index: dict[int | tuple[int, int], tuple[tuple[int, int], int]],
|
| 1047 |
+
is_nested: bool = False,
|
| 1048 |
+
) -> Union[list["torch.Tensor"], "torch.Tensor"]:
|
| 1049 |
+
"""
|
| 1050 |
+
Reconstructs images in the original order, preserving the original structure (nested or not).
|
| 1051 |
+
The input structure is either all flat or all nested.
|
| 1052 |
+
|
| 1053 |
+
Args:
|
| 1054 |
+
processed_images (dict[tuple[int, int], "torch.Tensor"]):
|
| 1055 |
+
Dictionary mapping shapes to batched processed images.
|
| 1056 |
+
grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
|
| 1057 |
+
Dictionary mapping original indices to (shape, index) tuples.
|
| 1058 |
+
is_nested (bool, *optional*, defaults to False):
|
| 1059 |
+
Whether the images are nested. Cannot be inferred from the input, as some processing functions outputs nested images.
|
| 1060 |
+
even with non nested images,e.g functions splitting images into patches. We thus can't deduce is_nested from the input.
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
Returns:
|
| 1064 |
+
Union[list["torch.Tensor"], "torch.Tensor"]:
|
| 1065 |
+
Images in the original structure.
|
| 1066 |
+
"""
|
| 1067 |
+
if not is_nested:
|
| 1068 |
+
return [
|
| 1069 |
+
processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
|
| 1070 |
+
for i in range(len(grouped_images_index))
|
| 1071 |
+
]
|
| 1072 |
+
|
| 1073 |
+
return _reconstruct_nested_structure(grouped_images_index, processed_images)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_gemma3 import *
|
| 22 |
+
from .image_processing_gemma3 import *
|
| 23 |
+
from .image_processing_pil_gemma3 import *
|
| 24 |
+
from .modeling_gemma3 import *
|
| 25 |
+
from .processing_gemma3 import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/configuration_gemma3.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_gemma3.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
from huggingface_hub.dataclasses import strict
|
| 24 |
+
|
| 25 |
+
from ...configuration_utils import PreTrainedConfig
|
| 26 |
+
from ...utils import auto_docstring, logging
|
| 27 |
+
from ..siglip import SiglipVisionConfig
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@auto_docstring(checkpoint="google/gemma-3-4b-it")
|
| 34 |
+
@strict
|
| 35 |
+
class Gemma3TextConfig(PreTrainedConfig):
|
| 36 |
+
r"""
|
| 37 |
+
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
| 38 |
+
scaling factor used on the attention scores
|
| 39 |
+
final_logit_softcapping (`float`, *optional*):
|
| 40 |
+
Scaling factor when applying tanh softcapping on the logits.
|
| 41 |
+
attn_logit_softcapping (`float`, *optional*):
|
| 42 |
+
Scaling factor when applying tanh softcapping on the attention scores.
|
| 43 |
+
use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
|
| 44 |
+
If True, the model will attend to all text tokens instead of using a causal mask. This does not change
|
| 45 |
+
behavior for vision tokens.
|
| 46 |
+
|
| 47 |
+
```python
|
| 48 |
+
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
|
| 49 |
+
>>> # Initializing a Gemma3Text gemma3_text-7b style configuration
|
| 50 |
+
>>> configuration = Gemma3TextConfig()
|
| 51 |
+
>>> # Initializing a model from the gemma3_text-7b style configuration
|
| 52 |
+
>>> model = Gemma3TextModel(configuration)
|
| 53 |
+
>>> # Accessing the model configuration
|
| 54 |
+
>>> configuration = model.config
|
| 55 |
+
```
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
model_type = "gemma3_text"
|
| 59 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 60 |
+
base_model_tp_plan = {
|
| 61 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 62 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 63 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 64 |
+
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
|
| 65 |
+
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
|
| 66 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 67 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 68 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 69 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 70 |
+
}
|
| 71 |
+
base_model_pp_plan = {
|
| 72 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 73 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 74 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
vocab_size: int = 262_208
|
| 78 |
+
hidden_size: int = 2304
|
| 79 |
+
intermediate_size: int = 9216
|
| 80 |
+
num_hidden_layers: int = 26
|
| 81 |
+
num_attention_heads: int = 8
|
| 82 |
+
num_key_value_heads: int = 4
|
| 83 |
+
head_dim: int = 256
|
| 84 |
+
hidden_activation: str = "gelu_pytorch_tanh"
|
| 85 |
+
max_position_embeddings: int = 131_072
|
| 86 |
+
initializer_range: float = 0.02
|
| 87 |
+
rms_norm_eps: float = 1e-6
|
| 88 |
+
use_cache: bool = True
|
| 89 |
+
pad_token_id: int | None = 0
|
| 90 |
+
eos_token_id: int | list[int] | None = 1
|
| 91 |
+
bos_token_id: int | None = 2
|
| 92 |
+
tie_word_embeddings: bool = True
|
| 93 |
+
rope_parameters: dict | None = None
|
| 94 |
+
attention_bias: bool = False
|
| 95 |
+
attention_dropout: int | float | None = 0.0
|
| 96 |
+
query_pre_attn_scalar: int = 256
|
| 97 |
+
sliding_window: int | None = 4096
|
| 98 |
+
layer_types: list[str] | None = None
|
| 99 |
+
final_logit_softcapping: float | None = None
|
| 100 |
+
attn_logit_softcapping: float | None = None
|
| 101 |
+
use_bidirectional_attention: bool | None = False
|
| 102 |
+
default_theta = {"global": 1_000_000.0, "local": 10_000.0}
|
| 103 |
+
|
| 104 |
+
def __post_init__(self, **kwargs):
|
| 105 |
+
if self.use_bidirectional_attention:
|
| 106 |
+
self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
|
| 107 |
+
|
| 108 |
+
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
|
| 109 |
+
self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
|
| 110 |
+
|
| 111 |
+
if self.layer_types is None:
|
| 112 |
+
self.layer_types = [
|
| 113 |
+
"sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
|
| 114 |
+
for i in range(self.num_hidden_layers)
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
super().__post_init__(**kwargs)
|
| 118 |
+
|
| 119 |
+
def validate_architecture(self):
|
| 120 |
+
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
| 121 |
+
if self.hidden_size % self.num_attention_heads != 0:
|
| 122 |
+
raise ValueError(
|
| 123 |
+
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
| 124 |
+
f"heads ({self.num_attention_heads})."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def convert_rope_params_to_dict(self, **kwargs):
|
| 128 |
+
rope_scaling = kwargs.pop("rope_scaling", None)
|
| 129 |
+
|
| 130 |
+
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
|
| 131 |
+
# as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
|
| 132 |
+
default_rope_params = {
|
| 133 |
+
"sliding_attention": {"rope_type": "default"},
|
| 134 |
+
"full_attention": {"rope_type": "default"},
|
| 135 |
+
}
|
| 136 |
+
self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else default_rope_params
|
| 137 |
+
if rope_scaling is not None:
|
| 138 |
+
self.rope_parameters["full_attention"].update(rope_scaling)
|
| 139 |
+
|
| 140 |
+
# Set default values if not present
|
| 141 |
+
if self.rope_parameters.get("full_attention") is None:
|
| 142 |
+
self.rope_parameters["full_attention"] = {"rope_type": "default"}
|
| 143 |
+
self.rope_parameters["full_attention"].setdefault(
|
| 144 |
+
"rope_theta", kwargs.pop("rope_theta", self.default_theta["global"])
|
| 145 |
+
)
|
| 146 |
+
if self.rope_parameters.get("sliding_attention") is None:
|
| 147 |
+
self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
|
| 148 |
+
self.rope_parameters["sliding_attention"].setdefault(
|
| 149 |
+
"rope_theta", kwargs.pop("rope_local_base_freq", self.default_theta["local"])
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# Standardize and validate the correctness of rotary position embeddings parameters
|
| 153 |
+
self.standardize_rope_params()
|
| 154 |
+
return kwargs
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@auto_docstring(checkpoint="google/gemma-3-4b-it")
|
| 158 |
+
@strict
|
| 159 |
+
class Gemma3Config(PreTrainedConfig):
|
| 160 |
+
r"""
|
| 161 |
+
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
| 162 |
+
The number of tokens per image embedding.
|
| 163 |
+
boi_token_index (`int`, *optional*, defaults to 255999):
|
| 164 |
+
The begin-of-image token index to wrap the image prompt.
|
| 165 |
+
eoi_token_index (`int`, *optional*, defaults to 256000):
|
| 166 |
+
The end-of-image token index to wrap the image prompt.
|
| 167 |
+
|
| 168 |
+
Example:
|
| 169 |
+
|
| 170 |
+
```python
|
| 171 |
+
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
| 172 |
+
|
| 173 |
+
>>> # Initializing a Siglip-like vision config
|
| 174 |
+
>>> vision_config = SiglipVisionConfig()
|
| 175 |
+
|
| 176 |
+
>>> # Initializing a Gemma3 Text config
|
| 177 |
+
>>> text_config = Gemma3TextConfig()
|
| 178 |
+
|
| 179 |
+
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
| 180 |
+
>>> configuration = Gemma3Config(vision_config, text_config)
|
| 181 |
+
|
| 182 |
+
>>> # Initializing a model from the gemma-3-4b style configuration
|
| 183 |
+
>>> model = Gemma3TextConfig(configuration)
|
| 184 |
+
|
| 185 |
+
>>> # Accessing the model configuration
|
| 186 |
+
>>> configuration = model.config
|
| 187 |
+
```"""
|
| 188 |
+
|
| 189 |
+
model_type = "gemma3"
|
| 190 |
+
attribute_map = {
|
| 191 |
+
"image_token_id": "image_token_index",
|
| 192 |
+
"boi_token_id": "boi_token_index",
|
| 193 |
+
"eoi_token_id": "eoi_token_index",
|
| 194 |
+
}
|
| 195 |
+
sub_configs = {
|
| 196 |
+
"text_config": Gemma3TextConfig,
|
| 197 |
+
"vision_config": SiglipVisionConfig,
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
text_config: Gemma3TextConfig | dict[str, Any] | None = None
|
| 201 |
+
vision_config: SiglipVisionConfig | dict[str, Any] | None = None
|
| 202 |
+
mm_tokens_per_image: int | None = 256
|
| 203 |
+
boi_token_index: int | None = 255_999
|
| 204 |
+
eoi_token_index: int | None = 256_000
|
| 205 |
+
image_token_index: int | None = 262_144
|
| 206 |
+
initializer_range: float | None = 0.02
|
| 207 |
+
tie_word_embeddings: bool | None = True
|
| 208 |
+
|
| 209 |
+
def __post_init__(self, **kwargs):
|
| 210 |
+
if self.text_config is None:
|
| 211 |
+
self.text_config = Gemma3TextConfig()
|
| 212 |
+
logger.info("text_config is None, using default Gemma3TextConfig text config.")
|
| 213 |
+
elif isinstance(self.text_config, dict):
|
| 214 |
+
self.text_config = Gemma3TextConfig(**self.text_config)
|
| 215 |
+
|
| 216 |
+
if isinstance(self.vision_config, dict):
|
| 217 |
+
self.vision_config = SiglipVisionConfig(**self.vision_config)
|
| 218 |
+
elif self.vision_config is None:
|
| 219 |
+
self.vision_config = SiglipVisionConfig()
|
| 220 |
+
logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
|
| 221 |
+
|
| 222 |
+
super().__post_init__(**kwargs)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
__all__ = ["Gemma3Config", "Gemma3TextConfig"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_gemma3.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Image processor class for Gemma3."""
|
| 15 |
+
|
| 16 |
+
import itertools
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torchvision.transforms.v2 import functional as tvF
|
| 21 |
+
|
| 22 |
+
from ...image_processing_backends import TorchvisionBackend
|
| 23 |
+
from ...image_processing_utils import BatchFeature
|
| 24 |
+
from ...image_transforms import group_images_by_shape, reorder_images
|
| 25 |
+
from ...image_utils import (
|
| 26 |
+
IMAGENET_STANDARD_MEAN,
|
| 27 |
+
IMAGENET_STANDARD_STD,
|
| 28 |
+
ImageInput,
|
| 29 |
+
PILImageResampling,
|
| 30 |
+
SizeDict,
|
| 31 |
+
)
|
| 32 |
+
from ...processing_utils import ImagesKwargs, Unpack
|
| 33 |
+
from ...utils import (
|
| 34 |
+
TensorType,
|
| 35 |
+
auto_docstring,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Gemma3ImageProcessorKwargs(ImagesKwargs, total=False):
|
| 40 |
+
"""
|
| 41 |
+
do_pan_and_scan (`bool`, *optional*):
|
| 42 |
+
Whether to apply `pan_and_scan` to images.
|
| 43 |
+
pan_and_scan_min_crop_size (`int`, *optional*):
|
| 44 |
+
Minimum size of each crop in pan and scan.
|
| 45 |
+
pan_and_scan_max_num_crops (`int`, *optional*):
|
| 46 |
+
Maximum number of crops per image in pan and scan.
|
| 47 |
+
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
| 48 |
+
Minimum aspect ratio to activate pan and scan.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
do_pan_and_scan: bool
|
| 52 |
+
pan_and_scan_min_crop_size: int
|
| 53 |
+
pan_and_scan_max_num_crops: int
|
| 54 |
+
pan_and_scan_min_ratio_to_activate: float
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@auto_docstring
|
| 58 |
+
class Gemma3ImageProcessor(TorchvisionBackend):
|
| 59 |
+
resample = PILImageResampling.BILINEAR
|
| 60 |
+
image_mean = IMAGENET_STANDARD_MEAN
|
| 61 |
+
image_std = IMAGENET_STANDARD_STD
|
| 62 |
+
size = {"height": 224, "width": 224}
|
| 63 |
+
default_to_square = True
|
| 64 |
+
do_convert_rgb = True
|
| 65 |
+
do_resize = True
|
| 66 |
+
do_rescale = True
|
| 67 |
+
do_normalize = True
|
| 68 |
+
do_pan_and_scan = None
|
| 69 |
+
pan_and_scan_min_crop_size = None
|
| 70 |
+
pan_and_scan_max_num_crops = None
|
| 71 |
+
pan_and_scan_min_ratio_to_activate = None
|
| 72 |
+
valid_kwargs = Gemma3ImageProcessorKwargs
|
| 73 |
+
model_input_names = ["pixel_values", "num_crops"]
|
| 74 |
+
|
| 75 |
+
def __init__(self, **kwargs: Unpack[Gemma3ImageProcessorKwargs]):
|
| 76 |
+
super().__init__(**kwargs)
|
| 77 |
+
|
| 78 |
+
@auto_docstring
|
| 79 |
+
def preprocess(self, images: ImageInput, **kwargs: Unpack[Gemma3ImageProcessorKwargs]) -> BatchFeature:
|
| 80 |
+
return super().preprocess(images, **kwargs)
|
| 81 |
+
|
| 82 |
+
def pan_and_scan_batched(
|
| 83 |
+
self,
|
| 84 |
+
images: "torch.Tensor",
|
| 85 |
+
pan_and_scan_min_crop_size: int,
|
| 86 |
+
pan_and_scan_max_num_crops: int,
|
| 87 |
+
pan_and_scan_min_ratio_to_activate: float,
|
| 88 |
+
):
|
| 89 |
+
"""
|
| 90 |
+
Pan and Scan an image, by cropping into smaller images when the aspect ratio exceeds
|
| 91 |
+
minimum allowed ratio.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
images (`torch.Tensor`):
|
| 95 |
+
Image to resize.
|
| 96 |
+
pan_and_scan_min_crop_size (`int`, *optional*):
|
| 97 |
+
Minimum size of each crop in pan and scan.
|
| 98 |
+
pan_and_scan_max_num_crops (`int`, *optional*):
|
| 99 |
+
Maximum number of crops per image in pan and scan.
|
| 100 |
+
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
| 101 |
+
Minimum aspect ratio to activate pan and scan.
|
| 102 |
+
"""
|
| 103 |
+
height, width = images.shape[-2:]
|
| 104 |
+
|
| 105 |
+
# Square or landscape image.
|
| 106 |
+
if width >= height:
|
| 107 |
+
# Only apply PaS if the image is sufficiently exaggerated
|
| 108 |
+
if width / height < pan_and_scan_min_ratio_to_activate:
|
| 109 |
+
return []
|
| 110 |
+
|
| 111 |
+
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
| 112 |
+
num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding.
|
| 113 |
+
num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)
|
| 114 |
+
|
| 115 |
+
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
| 116 |
+
num_crops_w = max(2, num_crops_w)
|
| 117 |
+
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
| 118 |
+
num_crops_h = 1
|
| 119 |
+
|
| 120 |
+
# Portrait image.
|
| 121 |
+
else:
|
| 122 |
+
# Only apply PaS if the image is sufficiently exaggerated
|
| 123 |
+
if height / width < pan_and_scan_min_ratio_to_activate:
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
| 127 |
+
num_crops_h = int(math.floor(height / width + 0.5))
|
| 128 |
+
num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)
|
| 129 |
+
|
| 130 |
+
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
| 131 |
+
num_crops_h = max(2, num_crops_h)
|
| 132 |
+
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
| 133 |
+
num_crops_w = 1
|
| 134 |
+
|
| 135 |
+
crop_size_w = int(math.ceil(width / num_crops_w))
|
| 136 |
+
crop_size_h = int(math.ceil(height / num_crops_h))
|
| 137 |
+
|
| 138 |
+
# Don't apply PaS if crop size is too small.
|
| 139 |
+
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
|
| 143 |
+
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
| 144 |
+
|
| 145 |
+
return [
|
| 146 |
+
images[..., pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
| 147 |
+
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
def _process_images_for_pan_and_scan(
|
| 151 |
+
self,
|
| 152 |
+
images: list["torch.Tensor"],
|
| 153 |
+
do_pan_and_scan: bool,
|
| 154 |
+
pan_and_scan_min_crop_size: int,
|
| 155 |
+
pan_and_scan_max_num_crops: int,
|
| 156 |
+
pan_and_scan_min_ratio_to_activate: float,
|
| 157 |
+
):
|
| 158 |
+
pas_images = self.pan_and_scan_batched(
|
| 159 |
+
images=images,
|
| 160 |
+
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
| 161 |
+
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
| 162 |
+
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
| 163 |
+
)
|
| 164 |
+
num_crops = [len(pas_images) for _ in images]
|
| 165 |
+
return pas_images, num_crops
|
| 166 |
+
|
| 167 |
+
def _preprocess(
|
| 168 |
+
self,
|
| 169 |
+
images: list["torch.Tensor"],
|
| 170 |
+
do_resize: bool,
|
| 171 |
+
size: SizeDict,
|
| 172 |
+
resample: "PILImageResampling | tvF.InterpolationMode | int | None",
|
| 173 |
+
do_rescale: bool,
|
| 174 |
+
rescale_factor: float,
|
| 175 |
+
do_normalize: bool,
|
| 176 |
+
image_mean: float | list[float] | None,
|
| 177 |
+
image_std: float | list[float] | None,
|
| 178 |
+
disable_grouping: bool | None,
|
| 179 |
+
return_tensors: str | TensorType | None,
|
| 180 |
+
do_pan_and_scan: bool | None = None,
|
| 181 |
+
pan_and_scan_min_crop_size: int | None = None,
|
| 182 |
+
pan_and_scan_max_num_crops: int | None = None,
|
| 183 |
+
pan_and_scan_min_ratio_to_activate: float | None = None,
|
| 184 |
+
**kwargs,
|
| 185 |
+
) -> BatchFeature:
|
| 186 |
+
# Group images by size for batched processing
|
| 187 |
+
processed_images_grouped = {}
|
| 188 |
+
num_crops_grouped = {}
|
| 189 |
+
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
| 190 |
+
for shape_images, stacked_images in grouped_images.items():
|
| 191 |
+
if do_pan_and_scan:
|
| 192 |
+
pas_images, num_crops = self._process_images_for_pan_and_scan(
|
| 193 |
+
images=stacked_images,
|
| 194 |
+
do_pan_and_scan=do_pan_and_scan,
|
| 195 |
+
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
| 196 |
+
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
| 197 |
+
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
| 198 |
+
)
|
| 199 |
+
# Add the thumbnails to the image patches
|
| 200 |
+
stacked_images = [stacked_images] + pas_images
|
| 201 |
+
# Group images by size for batched resizing (this will typically group thumbnails together and cropped patches together)
|
| 202 |
+
processed_image_patches_grouped = {}
|
| 203 |
+
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
| 204 |
+
stacked_images, disable_grouping=disable_grouping
|
| 205 |
+
)
|
| 206 |
+
for shape, stacked_image_patches in grouped_image_patches.items():
|
| 207 |
+
stacked_image_patches = self.resize(
|
| 208 |
+
image=stacked_image_patches,
|
| 209 |
+
size=size,
|
| 210 |
+
resample=resample,
|
| 211 |
+
)
|
| 212 |
+
processed_image_patches_grouped[shape] = stacked_image_patches
|
| 213 |
+
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
|
| 214 |
+
# Transpose to have the thumbnails with their corresponding patches
|
| 215 |
+
stacked_images = torch.stack(processed_image_patches, dim=0).transpose(0, 1).contiguous()
|
| 216 |
+
else:
|
| 217 |
+
num_crops = [0 for _ in stacked_images]
|
| 218 |
+
|
| 219 |
+
if do_resize:
|
| 220 |
+
stacked_images = self.resize(
|
| 221 |
+
image=stacked_images,
|
| 222 |
+
size=size,
|
| 223 |
+
resample=resample,
|
| 224 |
+
)
|
| 225 |
+
num_crops_grouped[shape_images] = num_crops
|
| 226 |
+
processed_images_grouped[shape_images] = stacked_images
|
| 227 |
+
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
|
| 228 |
+
# If pan and scan is enabled, we need to flatten the list of images
|
| 229 |
+
if do_pan_and_scan:
|
| 230 |
+
resized_images = [image for images_list in resized_images for image in images_list]
|
| 231 |
+
num_crops = reorder_images(num_crops_grouped, grouped_images_index)
|
| 232 |
+
|
| 233 |
+
# Group images by size for further processing
|
| 234 |
+
# Needed in case do_resize is False, or resize returns images with different sizes
|
| 235 |
+
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
| 236 |
+
processed_images_grouped = {}
|
| 237 |
+
for shape, stacked_images in grouped_images.items():
|
| 238 |
+
# Fused rescale and normalize
|
| 239 |
+
stacked_images = self.rescale_and_normalize(
|
| 240 |
+
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
| 241 |
+
)
|
| 242 |
+
processed_images_grouped[shape] = stacked_images
|
| 243 |
+
|
| 244 |
+
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
| 245 |
+
return BatchFeature(
|
| 246 |
+
data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
__all__ = ["Gemma3ImageProcessor"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_pil_gemma3.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Image processor class for Gemma3."""
|
| 15 |
+
|
| 16 |
+
import itertools
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...image_processing_backends import PilBackend
|
| 22 |
+
from ...image_processing_utils import BatchFeature
|
| 23 |
+
from ...image_utils import (
|
| 24 |
+
IMAGENET_STANDARD_MEAN,
|
| 25 |
+
IMAGENET_STANDARD_STD,
|
| 26 |
+
ImageInput,
|
| 27 |
+
PILImageResampling,
|
| 28 |
+
SizeDict,
|
| 29 |
+
get_image_size,
|
| 30 |
+
)
|
| 31 |
+
from ...processing_utils import ImagesKwargs, Unpack
|
| 32 |
+
from ...utils import (
|
| 33 |
+
TensorType,
|
| 34 |
+
auto_docstring,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Adapted from transformers.models.gemma3.image_processing_gemma3.Gemma3ImageProcessorKwargs
|
| 39 |
+
class Gemma3ImageProcessorKwargs(ImagesKwargs, total=False):
|
| 40 |
+
"""
|
| 41 |
+
do_pan_and_scan (`bool`, *optional*):
|
| 42 |
+
Whether to apply `pan_and_scan` to images.
|
| 43 |
+
pan_and_scan_min_crop_size (`int`, *optional*):
|
| 44 |
+
Minimum size of each crop in pan and scan.
|
| 45 |
+
pan_and_scan_max_num_crops (`int`, *optional*):
|
| 46 |
+
Maximum number of crops per image in pan and scan.
|
| 47 |
+
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
| 48 |
+
Minimum aspect ratio to activate pan and scan.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
do_pan_and_scan: bool
|
| 52 |
+
pan_and_scan_min_crop_size: int
|
| 53 |
+
pan_and_scan_max_num_crops: int
|
| 54 |
+
pan_and_scan_min_ratio_to_activate: float
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@auto_docstring
|
| 58 |
+
class Gemma3ImageProcessorPil(PilBackend):
|
| 59 |
+
resample = PILImageResampling.BILINEAR
|
| 60 |
+
image_mean = IMAGENET_STANDARD_MEAN
|
| 61 |
+
image_std = IMAGENET_STANDARD_STD
|
| 62 |
+
size = {"height": 224, "width": 224}
|
| 63 |
+
default_to_square = True
|
| 64 |
+
do_convert_rgb = True
|
| 65 |
+
do_resize = True
|
| 66 |
+
do_rescale = True
|
| 67 |
+
do_normalize = True
|
| 68 |
+
do_pan_and_scan = None
|
| 69 |
+
pan_and_scan_min_crop_size = None
|
| 70 |
+
pan_and_scan_max_num_crops = None
|
| 71 |
+
pan_and_scan_min_ratio_to_activate = None
|
| 72 |
+
valid_kwargs = Gemma3ImageProcessorKwargs
|
| 73 |
+
model_input_names = ["pixel_values", "num_crops"]
|
| 74 |
+
|
| 75 |
+
def __init__(self, **kwargs: Unpack[Gemma3ImageProcessorKwargs]):
|
| 76 |
+
super().__init__(**kwargs)
|
| 77 |
+
|
| 78 |
+
@auto_docstring
|
| 79 |
+
def preprocess(self, images: ImageInput, **kwargs: Unpack[Gemma3ImageProcessorKwargs]) -> BatchFeature:
|
| 80 |
+
return super().preprocess(images, **kwargs)
|
| 81 |
+
|
| 82 |
+
def pan_and_scan(
|
| 83 |
+
self,
|
| 84 |
+
image: np.ndarray,
|
| 85 |
+
pan_and_scan_min_crop_size: int,
|
| 86 |
+
pan_and_scan_max_num_crops: int,
|
| 87 |
+
pan_and_scan_min_ratio_to_activate: float,
|
| 88 |
+
):
|
| 89 |
+
"""
|
| 90 |
+
Pan and Scan an image, by cropping into smaller images when the aspect ratio exceeds
|
| 91 |
+
minimum allowed ratio.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
image (`np.ndarray`):
|
| 95 |
+
Image to resize.
|
| 96 |
+
pan_and_scan_min_crop_size (`int`, *optional*):
|
| 97 |
+
Minimum size of each crop in pan and scan.
|
| 98 |
+
pan_and_scan_max_num_crops (`int`, *optional*):
|
| 99 |
+
Maximum number of crops per image in pan and scan.
|
| 100 |
+
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
| 101 |
+
Minimum aspect ratio to activate pan and scan.
|
| 102 |
+
"""
|
| 103 |
+
height, width = get_image_size(image, channel_dim="channels_first")
|
| 104 |
+
|
| 105 |
+
# Square or landscape image.
|
| 106 |
+
if width >= height:
|
| 107 |
+
# Only apply PaS if the image is sufficiently exaggerated
|
| 108 |
+
if width / height < pan_and_scan_min_ratio_to_activate:
|
| 109 |
+
return []
|
| 110 |
+
|
| 111 |
+
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
| 112 |
+
num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding.
|
| 113 |
+
num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)
|
| 114 |
+
|
| 115 |
+
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
| 116 |
+
num_crops_w = max(2, num_crops_w)
|
| 117 |
+
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
| 118 |
+
num_crops_h = 1
|
| 119 |
+
|
| 120 |
+
# Portrait image.
|
| 121 |
+
else:
|
| 122 |
+
# Only apply PaS if the image is sufficiently exaggerated
|
| 123 |
+
if height / width < pan_and_scan_min_ratio_to_activate:
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
| 127 |
+
num_crops_h = int(math.floor(height / width + 0.5))
|
| 128 |
+
num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)
|
| 129 |
+
|
| 130 |
+
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
| 131 |
+
num_crops_h = max(2, num_crops_h)
|
| 132 |
+
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
| 133 |
+
num_crops_w = 1
|
| 134 |
+
|
| 135 |
+
crop_size_w = int(math.ceil(width / num_crops_w))
|
| 136 |
+
crop_size_h = int(math.ceil(height / num_crops_h))
|
| 137 |
+
|
| 138 |
+
# Don't apply PaS if crop size is too small.
|
| 139 |
+
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
|
| 143 |
+
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
| 144 |
+
|
| 145 |
+
# Images are channels-first (CHW format)
|
| 146 |
+
return [
|
| 147 |
+
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
| 148 |
+
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
def _process_images_for_pan_and_scan(
|
| 152 |
+
self,
|
| 153 |
+
images: list[np.ndarray],
|
| 154 |
+
do_pan_and_scan: bool,
|
| 155 |
+
pan_and_scan_min_crop_size: int,
|
| 156 |
+
pan_and_scan_max_num_crops: int,
|
| 157 |
+
pan_and_scan_min_ratio_to_activate: float,
|
| 158 |
+
):
|
| 159 |
+
pas_images_list = []
|
| 160 |
+
num_crops = []
|
| 161 |
+
for image in images:
|
| 162 |
+
pas_images = self.pan_and_scan(
|
| 163 |
+
image=image,
|
| 164 |
+
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
| 165 |
+
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
| 166 |
+
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
| 167 |
+
)
|
| 168 |
+
pas_images_list.extend([image] + pas_images)
|
| 169 |
+
num_crops.append(len(pas_images))
|
| 170 |
+
return pas_images_list, num_crops
|
| 171 |
+
|
| 172 |
+
def _preprocess(
|
| 173 |
+
self,
|
| 174 |
+
images: list[np.ndarray],
|
| 175 |
+
do_resize: bool,
|
| 176 |
+
size: SizeDict,
|
| 177 |
+
resample: "PILImageResampling | None",
|
| 178 |
+
do_rescale: bool,
|
| 179 |
+
rescale_factor: float,
|
| 180 |
+
do_normalize: bool,
|
| 181 |
+
image_mean: float | list[float] | None,
|
| 182 |
+
image_std: float | list[float] | None,
|
| 183 |
+
return_tensors: str | TensorType | None,
|
| 184 |
+
do_pan_and_scan: bool | None = None,
|
| 185 |
+
pan_and_scan_min_crop_size: int | None = None,
|
| 186 |
+
pan_and_scan_max_num_crops: int | None = None,
|
| 187 |
+
pan_and_scan_min_ratio_to_activate: float | None = None,
|
| 188 |
+
**kwargs,
|
| 189 |
+
) -> BatchFeature:
|
| 190 |
+
processed_images = []
|
| 191 |
+
num_crops = []
|
| 192 |
+
|
| 193 |
+
for image in images:
|
| 194 |
+
if do_pan_and_scan:
|
| 195 |
+
pas_images = self.pan_and_scan(
|
| 196 |
+
image=image,
|
| 197 |
+
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
| 198 |
+
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
| 199 |
+
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
| 200 |
+
)
|
| 201 |
+
# Add the original image and its crops
|
| 202 |
+
image_list = [image] + pas_images
|
| 203 |
+
num_crops.append(len(pas_images))
|
| 204 |
+
else:
|
| 205 |
+
image_list = [image]
|
| 206 |
+
num_crops.append(0)
|
| 207 |
+
|
| 208 |
+
# Process each image (original + crops if pan_and_scan)
|
| 209 |
+
processed_image_list = []
|
| 210 |
+
for img in image_list:
|
| 211 |
+
if do_resize:
|
| 212 |
+
img = self.resize(image=img, size=size, resample=resample)
|
| 213 |
+
if do_rescale:
|
| 214 |
+
img = self.rescale(image=img, scale=rescale_factor)
|
| 215 |
+
if do_normalize:
|
| 216 |
+
img = self.normalize(image=img, mean=image_mean, std=image_std)
|
| 217 |
+
processed_image_list.append(img)
|
| 218 |
+
processed_images.extend(processed_image_list)
|
| 219 |
+
|
| 220 |
+
return BatchFeature(
|
| 221 |
+
data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
__all__ = ["Gemma3ImageProcessorPil"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py
ADDED
|
@@ -0,0 +1,1118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_gemma3.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
from collections.abc import Callable
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
|
| 28 |
+
from ... import initialization as init
|
| 29 |
+
from ...activations import ACT2FN
|
| 30 |
+
from ...cache_utils import Cache, DynamicCache
|
| 31 |
+
from ...configuration_utils import PreTrainedConfig
|
| 32 |
+
from ...generation import GenerationMixin
|
| 33 |
+
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
|
| 34 |
+
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
| 35 |
+
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
|
| 36 |
+
from ...modeling_outputs import (
|
| 37 |
+
BaseModelOutputWithPast,
|
| 38 |
+
BaseModelOutputWithPooling,
|
| 39 |
+
CausalLMOutputWithPast,
|
| 40 |
+
SequenceClassifierOutputWithPast,
|
| 41 |
+
)
|
| 42 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 43 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 44 |
+
from ...processing_utils import Unpack
|
| 45 |
+
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
|
| 46 |
+
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
| 47 |
+
from ...utils.output_capturing import capture_outputs
|
| 48 |
+
from ..auto import AutoModel
|
| 49 |
+
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@auto_docstring(
|
| 53 |
+
custom_intro="""
|
| 54 |
+
Base class for Gemma3 outputs, with hidden states and attentions.
|
| 55 |
+
"""
|
| 56 |
+
)
|
| 57 |
+
@dataclass
|
| 58 |
+
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
|
| 59 |
+
r"""
|
| 60 |
+
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 61 |
+
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 62 |
+
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
image_hidden_states: torch.FloatTensor | None = None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@auto_docstring(
|
| 69 |
+
custom_intro="""
|
| 70 |
+
Base class for Gemma3 causal language model (or autoregressive) outputs.
|
| 71 |
+
"""
|
| 72 |
+
)
|
| 73 |
+
@dataclass
|
| 74 |
+
class Gemma3CausalLMOutputWithPast(ModelOutput):
|
| 75 |
+
r"""
|
| 76 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 77 |
+
Language modeling loss (for next-token prediction).
|
| 78 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
| 79 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 80 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 81 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 82 |
+
|
| 83 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 84 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 85 |
+
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 86 |
+
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 87 |
+
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
loss: torch.FloatTensor | None = None
|
| 91 |
+
logits: torch.FloatTensor | None = None
|
| 92 |
+
past_key_values: Cache | None = None
|
| 93 |
+
hidden_states: tuple[torch.FloatTensor] | None = None
|
| 94 |
+
attentions: tuple[torch.FloatTensor] | None = None
|
| 95 |
+
image_hidden_states: torch.FloatTensor | None = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Gemma3TextScaledWordEmbedding(nn.Embedding):
|
| 99 |
+
"""
|
| 100 |
+
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
| 104 |
+
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
| 105 |
+
self.scalar_embed_scale = embed_scale
|
| 106 |
+
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
| 107 |
+
|
| 108 |
+
def forward(self, input_ids: torch.Tensor):
|
| 109 |
+
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Gemma3MLP(nn.Module):
|
| 113 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.config = config
|
| 116 |
+
self.hidden_size = config.hidden_size
|
| 117 |
+
self.intermediate_size = config.intermediate_size
|
| 118 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 119 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 120 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 121 |
+
self.act_fn = ACT2FN[config.hidden_activation]
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 125 |
+
return down_proj
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class Gemma3RMSNorm(nn.Module):
|
| 129 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 130 |
+
super().__init__()
|
| 131 |
+
self.eps = eps
|
| 132 |
+
self.weight = nn.Parameter(torch.zeros(dim))
|
| 133 |
+
|
| 134 |
+
def _norm(self, x):
|
| 135 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 136 |
+
|
| 137 |
+
def forward(self, x):
|
| 138 |
+
output = self._norm(x.float())
|
| 139 |
+
# Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
|
| 140 |
+
# See https://github.com/huggingface/transformers/pull/29402
|
| 141 |
+
output = output * (1.0 + self.weight.float())
|
| 142 |
+
return output.type_as(x)
|
| 143 |
+
|
| 144 |
+
def extra_repr(self):
|
| 145 |
+
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Gemma3RotaryEmbedding(nn.Module):
|
| 149 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 150 |
+
|
| 151 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 154 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 155 |
+
self.config = config
|
| 156 |
+
self.layer_types = list(set(config.layer_types))
|
| 157 |
+
self.rope_type = {}
|
| 158 |
+
for layer_type in self.layer_types:
|
| 159 |
+
rope_params = self.config.rope_parameters[layer_type]
|
| 160 |
+
if rope_params is None:
|
| 161 |
+
continue
|
| 162 |
+
|
| 163 |
+
self.rope_type[layer_type] = rope_params["rope_type"]
|
| 164 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 165 |
+
if self.rope_type[layer_type] != "default":
|
| 166 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
|
| 167 |
+
curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type)
|
| 168 |
+
self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
|
| 169 |
+
self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
|
| 170 |
+
setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
|
| 171 |
+
|
| 172 |
+
@staticmethod
|
| 173 |
+
def compute_default_rope_parameters(
|
| 174 |
+
config: Gemma3TextConfig | None = None,
|
| 175 |
+
device: Optional["torch.device"] = None,
|
| 176 |
+
seq_len: int | None = None,
|
| 177 |
+
layer_type: str | None = None,
|
| 178 |
+
) -> tuple["torch.Tensor", float]:
|
| 179 |
+
"""
|
| 180 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 181 |
+
Args:
|
| 182 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 183 |
+
The model configuration.
|
| 184 |
+
device (`torch.device`):
|
| 185 |
+
The device to use for initialization of the inverse frequencies.
|
| 186 |
+
seq_len (`int`, *optional*):
|
| 187 |
+
The current sequence length. Unused for this type of RoPE.
|
| 188 |
+
layer_type (`str`, *optional*):
|
| 189 |
+
The current layer type if the model has different RoPE parameters per type.
|
| 190 |
+
Should not be used unless `config.layer_types is not None`
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 194 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 195 |
+
"""
|
| 196 |
+
# For backward compatibility standardize the `rope_parameters_dict` if it uses old format
|
| 197 |
+
base = config.rope_parameters[layer_type]["rope_theta"]
|
| 198 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 199 |
+
|
| 200 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 201 |
+
|
| 202 |
+
# Compute the inverse frequencies
|
| 203 |
+
inv_freq = 1.0 / (
|
| 204 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 205 |
+
)
|
| 206 |
+
return inv_freq, attention_factor
|
| 207 |
+
|
| 208 |
+
@torch.no_grad()
|
| 209 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 210 |
+
def forward(self, x, position_ids, layer_type=None):
|
| 211 |
+
inv_freq = getattr(self, f"{layer_type}_inv_freq")
|
| 212 |
+
attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
|
| 213 |
+
|
| 214 |
+
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 215 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 216 |
+
|
| 217 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 218 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 219 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 220 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 221 |
+
cos = emb.cos() * attention_scaling
|
| 222 |
+
sin = emb.sin() * attention_scaling
|
| 223 |
+
|
| 224 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def rotate_half(x):
|
| 228 |
+
"""Rotates half the hidden dims of the input."""
|
| 229 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 230 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 231 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@use_kernel_func_from_hub("rotary_pos_emb")
|
| 235 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 236 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
q (`torch.Tensor`): The query tensor.
|
| 240 |
+
k (`torch.Tensor`): The key tensor.
|
| 241 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 242 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 243 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 244 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 245 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 246 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 247 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 248 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 249 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 250 |
+
Returns:
|
| 251 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 252 |
+
"""
|
| 253 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 254 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 255 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 256 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 257 |
+
return q_embed, k_embed
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 261 |
+
"""
|
| 262 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 263 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 264 |
+
"""
|
| 265 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 266 |
+
if n_rep == 1:
|
| 267 |
+
return hidden_states
|
| 268 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 269 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def eager_attention_forward(
|
| 273 |
+
module: nn.Module,
|
| 274 |
+
query: torch.Tensor,
|
| 275 |
+
key: torch.Tensor,
|
| 276 |
+
value: torch.Tensor,
|
| 277 |
+
attention_mask: torch.Tensor | None,
|
| 278 |
+
dropout: float | int = 0.0,
|
| 279 |
+
scaling: float | None = None,
|
| 280 |
+
softcap: float | None = None,
|
| 281 |
+
**kwargs,
|
| 282 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 283 |
+
if scaling is None:
|
| 284 |
+
scaling = module.head_dim**-0.5
|
| 285 |
+
|
| 286 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 287 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 288 |
+
|
| 289 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 290 |
+
|
| 291 |
+
if softcap is not None:
|
| 292 |
+
attn_weights = attn_weights / softcap
|
| 293 |
+
attn_weights = torch.tanh(attn_weights)
|
| 294 |
+
attn_weights = attn_weights * softcap
|
| 295 |
+
if attention_mask is not None:
|
| 296 |
+
attn_weights = attn_weights + attention_mask
|
| 297 |
+
|
| 298 |
+
# upcast attention to fp32
|
| 299 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 300 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 301 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 302 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 303 |
+
return attn_output, attn_weights
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@use_kernelized_func(apply_rotary_pos_emb)
|
| 307 |
+
class Gemma3Attention(nn.Module):
|
| 308 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 309 |
+
|
| 310 |
+
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
| 313 |
+
self.config = config
|
| 314 |
+
self.layer_idx = layer_idx
|
| 315 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 316 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 317 |
+
self.scaling = config.query_pre_attn_scalar**-0.5
|
| 318 |
+
self.attention_dropout = self.config.attention_dropout
|
| 319 |
+
self.is_causal = not self.config.use_bidirectional_attention
|
| 320 |
+
|
| 321 |
+
self.q_proj = nn.Linear(
|
| 322 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 323 |
+
)
|
| 324 |
+
self.k_proj = nn.Linear(
|
| 325 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 326 |
+
)
|
| 327 |
+
self.v_proj = nn.Linear(
|
| 328 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 329 |
+
)
|
| 330 |
+
self.o_proj = nn.Linear(
|
| 331 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 332 |
+
)
|
| 333 |
+
self.attn_logit_softcapping = self.config.attn_logit_softcapping
|
| 334 |
+
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
|
| 335 |
+
self.is_sliding = self.layer_type == "sliding_attention"
|
| 336 |
+
|
| 337 |
+
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 338 |
+
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 339 |
+
|
| 340 |
+
def forward(
|
| 341 |
+
self,
|
| 342 |
+
hidden_states: torch.Tensor,
|
| 343 |
+
position_embeddings: torch.Tensor = None,
|
| 344 |
+
attention_mask: torch.Tensor | None = None,
|
| 345 |
+
past_key_values: Cache | None = None,
|
| 346 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 347 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
| 348 |
+
input_shape = hidden_states.shape[:-1]
|
| 349 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 350 |
+
|
| 351 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 352 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 353 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 354 |
+
|
| 355 |
+
query_states = self.q_norm(query_states)
|
| 356 |
+
key_states = self.k_norm(key_states)
|
| 357 |
+
|
| 358 |
+
cos, sin = position_embeddings
|
| 359 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 360 |
+
|
| 361 |
+
if past_key_values is not None:
|
| 362 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
| 363 |
+
|
| 364 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 365 |
+
self.config._attn_implementation, eager_attention_forward
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
attn_output, attn_weights = attention_interface(
|
| 369 |
+
self,
|
| 370 |
+
query_states,
|
| 371 |
+
key_states,
|
| 372 |
+
value_states,
|
| 373 |
+
attention_mask,
|
| 374 |
+
dropout=self.attention_dropout if self.training else 0.0,
|
| 375 |
+
scaling=self.scaling,
|
| 376 |
+
sliding_window=self.sliding_window,
|
| 377 |
+
**kwargs,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 381 |
+
attn_output = self.o_proj(attn_output)
|
| 382 |
+
return attn_output, attn_weights
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
| 386 |
+
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
| 387 |
+
super().__init__()
|
| 388 |
+
self.config = config
|
| 389 |
+
self.hidden_size = config.hidden_size
|
| 390 |
+
self.layer_idx = layer_idx
|
| 391 |
+
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
|
| 392 |
+
self.mlp = Gemma3MLP(config)
|
| 393 |
+
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 394 |
+
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 395 |
+
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 396 |
+
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 397 |
+
|
| 398 |
+
def forward(
|
| 399 |
+
self,
|
| 400 |
+
hidden_states: torch.Tensor,
|
| 401 |
+
position_embeddings: torch.Tensor = None,
|
| 402 |
+
attention_mask: torch.Tensor | None = None,
|
| 403 |
+
position_ids: torch.LongTensor | None = None,
|
| 404 |
+
past_key_values: Cache | None = None,
|
| 405 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 406 |
+
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
| 407 |
+
residual = hidden_states
|
| 408 |
+
|
| 409 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 410 |
+
|
| 411 |
+
hidden_states, _ = self.self_attn(
|
| 412 |
+
hidden_states=hidden_states,
|
| 413 |
+
position_embeddings=position_embeddings,
|
| 414 |
+
attention_mask=attention_mask,
|
| 415 |
+
position_ids=position_ids,
|
| 416 |
+
past_key_values=past_key_values,
|
| 417 |
+
**kwargs,
|
| 418 |
+
)
|
| 419 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 420 |
+
hidden_states = residual + hidden_states
|
| 421 |
+
|
| 422 |
+
residual = hidden_states
|
| 423 |
+
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
| 424 |
+
hidden_states = self.mlp(hidden_states)
|
| 425 |
+
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
| 426 |
+
hidden_states = residual + hidden_states
|
| 427 |
+
|
| 428 |
+
return hidden_states
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
@auto_docstring
|
| 432 |
+
class Gemma3PreTrainedModel(PreTrainedModel):
|
| 433 |
+
config: Gemma3Config
|
| 434 |
+
base_model_prefix = "model"
|
| 435 |
+
supports_gradient_checkpointing = True
|
| 436 |
+
_no_split_modules = [
|
| 437 |
+
"Gemma3DecoderLayer",
|
| 438 |
+
"SiglipVisionEmbeddings",
|
| 439 |
+
"SiglipEncoderLayer",
|
| 440 |
+
"SiglipMultiheadAttentionPoolingHead",
|
| 441 |
+
]
|
| 442 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 443 |
+
_supports_flash_attn = True
|
| 444 |
+
_supports_sdpa = True
|
| 445 |
+
_supports_flex_attn = True
|
| 446 |
+
|
| 447 |
+
_can_compile_fullgraph = True
|
| 448 |
+
_supports_attention_backend = True
|
| 449 |
+
_can_record_outputs = {
|
| 450 |
+
"hidden_states": Gemma3DecoderLayer,
|
| 451 |
+
"attentions": Gemma3Attention,
|
| 452 |
+
}
|
| 453 |
+
input_modalities = ("image", "text")
|
| 454 |
+
|
| 455 |
+
@torch.no_grad()
|
| 456 |
+
def _init_weights(self, module):
|
| 457 |
+
super()._init_weights(module)
|
| 458 |
+
if isinstance(module, Gemma3MultiModalProjector):
|
| 459 |
+
init.zeros_(module.mm_input_projection_weight)
|
| 460 |
+
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
| 461 |
+
elif "RMSNorm" in module.__class__.__name__:
|
| 462 |
+
init.zeros_(module.weight)
|
| 463 |
+
elif isinstance(module, Gemma3TextScaledWordEmbedding):
|
| 464 |
+
init.constant_(module.embed_scale, module.scalar_embed_scale)
|
| 465 |
+
elif isinstance(module, Gemma3RotaryEmbedding):
|
| 466 |
+
for layer_type in module.layer_types:
|
| 467 |
+
rope_init_fn = module.compute_default_rope_parameters
|
| 468 |
+
if module.rope_type[layer_type] != "default":
|
| 469 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
|
| 470 |
+
curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
|
| 471 |
+
init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
|
| 472 |
+
init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
|
| 476 |
+
"""
|
| 477 |
+
Enables a bidirectional mask within the sliding window.
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 481 |
+
"""A token can attend to any other token if their absolute distance is within
|
| 482 |
+
the (exclusive) sliding window size (distance < sliding_window)."""
|
| 483 |
+
return abs(q_idx - kv_idx) < sliding_window
|
| 484 |
+
|
| 485 |
+
return inner_mask
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
@auto_docstring
|
| 489 |
+
class Gemma3TextModel(Gemma3PreTrainedModel):
|
| 490 |
+
config: Gemma3TextConfig
|
| 491 |
+
input_modalities = ("text",)
|
| 492 |
+
|
| 493 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 494 |
+
super().__init__(config)
|
| 495 |
+
self.padding_idx = config.pad_token_id
|
| 496 |
+
self.vocab_size = config.vocab_size
|
| 497 |
+
|
| 498 |
+
# Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
|
| 499 |
+
self.embed_tokens = Gemma3TextScaledWordEmbedding(
|
| 500 |
+
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
|
| 501 |
+
)
|
| 502 |
+
self.layers = nn.ModuleList(
|
| 503 |
+
[Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 504 |
+
)
|
| 505 |
+
self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 506 |
+
self.rotary_emb = Gemma3RotaryEmbedding(config)
|
| 507 |
+
self.gradient_checkpointing = False
|
| 508 |
+
|
| 509 |
+
# Initialize weights and apply final processing
|
| 510 |
+
self.post_init()
|
| 511 |
+
|
| 512 |
+
@merge_with_config_defaults
|
| 513 |
+
@capture_outputs
|
| 514 |
+
@auto_docstring
|
| 515 |
+
def forward(
|
| 516 |
+
self,
|
| 517 |
+
input_ids: torch.LongTensor | None = None,
|
| 518 |
+
attention_mask: torch.Tensor | None = None,
|
| 519 |
+
position_ids: torch.LongTensor | None = None,
|
| 520 |
+
past_key_values: Cache | None = None,
|
| 521 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 522 |
+
use_cache: bool | None = None,
|
| 523 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 524 |
+
) -> BaseModelOutputWithPast:
|
| 525 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 526 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 527 |
+
|
| 528 |
+
if inputs_embeds is None:
|
| 529 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 530 |
+
|
| 531 |
+
if use_cache and past_key_values is None:
|
| 532 |
+
past_key_values = DynamicCache(config=self.config)
|
| 533 |
+
|
| 534 |
+
if position_ids is None:
|
| 535 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 536 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 537 |
+
position_ids = position_ids.unsqueeze(0)
|
| 538 |
+
|
| 539 |
+
# It may already have been prepared by e.g. `generate`
|
| 540 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 541 |
+
# Prepare mask arguments
|
| 542 |
+
mask_kwargs = {
|
| 543 |
+
"config": self.config,
|
| 544 |
+
"inputs_embeds": inputs_embeds,
|
| 545 |
+
"attention_mask": attention_mask,
|
| 546 |
+
"past_key_values": past_key_values,
|
| 547 |
+
"position_ids": position_ids,
|
| 548 |
+
}
|
| 549 |
+
sliding_mask_kwargs = mask_kwargs.copy()
|
| 550 |
+
|
| 551 |
+
if self.config.use_bidirectional_attention:
|
| 552 |
+
mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
|
| 553 |
+
sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
|
| 554 |
+
|
| 555 |
+
# Create the masks
|
| 556 |
+
causal_mask_mapping = {
|
| 557 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 558 |
+
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
# embed positions
|
| 562 |
+
hidden_states = inputs_embeds
|
| 563 |
+
position_embeddings = {}
|
| 564 |
+
for layer_type in set(self.config.layer_types):
|
| 565 |
+
position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
|
| 566 |
+
|
| 567 |
+
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
|
| 568 |
+
hidden_states = decoder_layer(
|
| 569 |
+
hidden_states,
|
| 570 |
+
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
|
| 571 |
+
position_embeddings=position_embeddings[self.config.layer_types[i]],
|
| 572 |
+
position_ids=position_ids,
|
| 573 |
+
past_key_values=past_key_values,
|
| 574 |
+
**kwargs,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
hidden_states = self.norm(hidden_states)
|
| 578 |
+
|
| 579 |
+
return BaseModelOutputWithPast(
|
| 580 |
+
last_hidden_state=hidden_states,
|
| 581 |
+
past_key_values=past_key_values,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
@auto_docstring
|
| 586 |
+
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
| 587 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 588 |
+
_tp_plan = {"lm_head": "colwise_gather_output"}
|
| 589 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 590 |
+
config: Gemma3TextConfig
|
| 591 |
+
|
| 592 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 593 |
+
super().__init__(config)
|
| 594 |
+
self.model = Gemma3TextModel(config)
|
| 595 |
+
self.vocab_size = config.vocab_size
|
| 596 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 597 |
+
|
| 598 |
+
# Initialize weights and apply final processing
|
| 599 |
+
self.post_init()
|
| 600 |
+
|
| 601 |
+
@can_return_tuple
|
| 602 |
+
@auto_docstring
|
| 603 |
+
def forward(
|
| 604 |
+
self,
|
| 605 |
+
input_ids: torch.LongTensor | None = None,
|
| 606 |
+
attention_mask: torch.Tensor | None = None,
|
| 607 |
+
position_ids: torch.LongTensor | None = None,
|
| 608 |
+
past_key_values: Cache | None = None,
|
| 609 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 610 |
+
labels: torch.LongTensor | None = None,
|
| 611 |
+
use_cache: bool | None = None,
|
| 612 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 613 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 614 |
+
) -> CausalLMOutputWithPast:
|
| 615 |
+
r"""
|
| 616 |
+
Example:
|
| 617 |
+
|
| 618 |
+
```python
|
| 619 |
+
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
| 620 |
+
|
| 621 |
+
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
| 622 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
| 623 |
+
|
| 624 |
+
>>> prompt = "What is your favorite condiment?"
|
| 625 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 626 |
+
|
| 627 |
+
>>> # Generate
|
| 628 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 629 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 630 |
+
"What is your favorite condiment?"
|
| 631 |
+
```"""
|
| 632 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 633 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 634 |
+
input_ids=input_ids,
|
| 635 |
+
attention_mask=attention_mask,
|
| 636 |
+
position_ids=position_ids,
|
| 637 |
+
past_key_values=past_key_values,
|
| 638 |
+
inputs_embeds=inputs_embeds,
|
| 639 |
+
use_cache=use_cache,
|
| 640 |
+
**kwargs,
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
hidden_states = outputs.last_hidden_state
|
| 644 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 645 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 646 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 647 |
+
if self.config.final_logit_softcapping is not None:
|
| 648 |
+
logits = logits / self.config.final_logit_softcapping
|
| 649 |
+
logits = torch.tanh(logits)
|
| 650 |
+
logits = logits * self.config.final_logit_softcapping
|
| 651 |
+
|
| 652 |
+
loss = None
|
| 653 |
+
if labels is not None:
|
| 654 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
| 655 |
+
|
| 656 |
+
return CausalLMOutputWithPast(
|
| 657 |
+
loss=loss,
|
| 658 |
+
logits=logits,
|
| 659 |
+
past_key_values=outputs.past_key_values,
|
| 660 |
+
hidden_states=outputs.hidden_states,
|
| 661 |
+
attentions=outputs.attentions,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
class Gemma3MultiModalProjector(nn.Module):
|
| 666 |
+
def __init__(self, config: Gemma3Config):
|
| 667 |
+
super().__init__()
|
| 668 |
+
|
| 669 |
+
self.mm_input_projection_weight = nn.Parameter(
|
| 670 |
+
torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
self.mm_soft_emb_norm = Gemma3RMSNorm(
|
| 674 |
+
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
|
| 678 |
+
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
| 679 |
+
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
| 680 |
+
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
|
| 681 |
+
|
| 682 |
+
def forward(self, vision_outputs: torch.Tensor):
|
| 683 |
+
batch_size, _, hidden_size = vision_outputs.shape
|
| 684 |
+
|
| 685 |
+
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
| 686 |
+
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
| 687 |
+
batch_size, hidden_size, self.patches_per_image, self.patches_per_image
|
| 688 |
+
)
|
| 689 |
+
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
| 690 |
+
|
| 691 |
+
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
| 692 |
+
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
| 693 |
+
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
| 694 |
+
|
| 695 |
+
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
|
| 696 |
+
|
| 697 |
+
projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
|
| 698 |
+
return projected_vision_outputs.type_as(vision_outputs)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch.device | None = None) -> torch.Tensor:
|
| 702 |
+
# First find where a new image block starts: 1 if image and previous not image
|
| 703 |
+
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
| 704 |
+
is_image = (token_type_ids == 1).to(device=device)
|
| 705 |
+
is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
| 706 |
+
new_image_start = is_image & ~is_previous_image
|
| 707 |
+
group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
| 708 |
+
block_sequence_ids = torch.where(is_image, group_ids, -1)
|
| 709 |
+
return block_sequence_ids
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
@auto_docstring(
|
| 713 |
+
custom_intro="""
|
| 714 |
+
The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
|
| 715 |
+
"""
|
| 716 |
+
)
|
| 717 |
+
class Gemma3Model(Gemma3PreTrainedModel):
|
| 718 |
+
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
| 719 |
+
accepts_loss_kwargs = False
|
| 720 |
+
|
| 721 |
+
def __init__(self, config: Gemma3Config):
|
| 722 |
+
super().__init__(config)
|
| 723 |
+
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
| 724 |
+
self.multi_modal_projector = Gemma3MultiModalProjector(config)
|
| 725 |
+
self.vocab_size = config.text_config.vocab_size
|
| 726 |
+
|
| 727 |
+
language_model = AutoModel.from_config(config=config.text_config)
|
| 728 |
+
self.language_model = language_model
|
| 729 |
+
self.post_init()
|
| 730 |
+
|
| 731 |
+
@can_return_tuple
|
| 732 |
+
@auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
|
| 733 |
+
def get_image_features(
|
| 734 |
+
self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
|
| 735 |
+
) -> tuple | BaseModelOutputWithPooling:
|
| 736 |
+
vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
|
| 737 |
+
last_hidden_state = vision_outputs.last_hidden_state
|
| 738 |
+
vision_outputs.pooler_output = self.multi_modal_projector(last_hidden_state)
|
| 739 |
+
|
| 740 |
+
return vision_outputs
|
| 741 |
+
|
| 742 |
+
def get_placeholder_mask(
|
| 743 |
+
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
| 744 |
+
):
|
| 745 |
+
"""
|
| 746 |
+
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
| 747 |
+
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
| 748 |
+
"""
|
| 749 |
+
if input_ids is None:
|
| 750 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 751 |
+
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 752 |
+
)
|
| 753 |
+
special_image_mask = special_image_mask.all(-1)
|
| 754 |
+
else:
|
| 755 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 756 |
+
|
| 757 |
+
n_image_tokens = special_image_mask.sum()
|
| 758 |
+
n_image_features = image_features.shape[0] * image_features.shape[1]
|
| 759 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 760 |
+
torch_compilable_check(
|
| 761 |
+
inputs_embeds[special_image_mask].numel() == image_features.numel(),
|
| 762 |
+
f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
|
| 763 |
+
)
|
| 764 |
+
return special_image_mask
|
| 765 |
+
|
| 766 |
+
@can_return_tuple
|
| 767 |
+
@auto_docstring
|
| 768 |
+
def forward(
|
| 769 |
+
self,
|
| 770 |
+
input_ids: torch.LongTensor | None = None,
|
| 771 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 772 |
+
attention_mask: torch.Tensor | None = None,
|
| 773 |
+
position_ids: torch.LongTensor | None = None,
|
| 774 |
+
past_key_values: Cache | None = None,
|
| 775 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 776 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 777 |
+
labels: torch.LongTensor | None = None,
|
| 778 |
+
use_cache: bool | None = None,
|
| 779 |
+
**lm_kwargs: Unpack[TransformersKwargs],
|
| 780 |
+
) -> tuple | Gemma3ModelOutputWithPast:
|
| 781 |
+
r"""
|
| 782 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 783 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 784 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 785 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 786 |
+
|
| 787 |
+
Example:
|
| 788 |
+
|
| 789 |
+
```python
|
| 790 |
+
>>> from PIL import Image
|
| 791 |
+
>>> import httpx
|
| 792 |
+
>>> from io import BytesIO
|
| 793 |
+
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 794 |
+
|
| 795 |
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma32-3b-mix-224")
|
| 796 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma32-3b-mix-224")
|
| 797 |
+
|
| 798 |
+
>>> prompt = "Where is the cat standing?"
|
| 799 |
+
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
| 800 |
+
>>> with httpx.stream("GET", url) as response:
|
| 801 |
+
... image = Image.open(BytesIO(response.read()))
|
| 802 |
+
|
| 803 |
+
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
| 804 |
+
|
| 805 |
+
>>> # Generate
|
| 806 |
+
>>> generate_ids = model.generate(**inputs,)
|
| 807 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 808 |
+
"Where is the cat standing?\nsnow"
|
| 809 |
+
```"""
|
| 810 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 811 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 812 |
+
|
| 813 |
+
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
| 814 |
+
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
| 815 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 816 |
+
llm_input_ids = input_ids.clone()
|
| 817 |
+
llm_input_ids[special_image_mask] = 0
|
| 818 |
+
else:
|
| 819 |
+
llm_input_ids = input_ids
|
| 820 |
+
|
| 821 |
+
if inputs_embeds is None:
|
| 822 |
+
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
| 823 |
+
|
| 824 |
+
# Merge text and images
|
| 825 |
+
if pixel_values is not None:
|
| 826 |
+
image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
|
| 827 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 828 |
+
special_image_mask = self.get_placeholder_mask(
|
| 829 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
| 830 |
+
)
|
| 831 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 832 |
+
|
| 833 |
+
# It may already have been prepared by e.g. `generate`
|
| 834 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 835 |
+
mask_kwargs = {
|
| 836 |
+
"config": self.config.get_text_config(),
|
| 837 |
+
"inputs_embeds": inputs_embeds,
|
| 838 |
+
"attention_mask": attention_mask,
|
| 839 |
+
"past_key_values": past_key_values,
|
| 840 |
+
"position_ids": position_ids,
|
| 841 |
+
}
|
| 842 |
+
|
| 843 |
+
if token_type_ids is not None:
|
| 844 |
+
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
|
| 845 |
+
token_type_ids, device=inputs_embeds.device
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
# Create the masks
|
| 849 |
+
sliding_mask_kwargs = mask_kwargs.copy()
|
| 850 |
+
causal_mask_mapping = {
|
| 851 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 852 |
+
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
outputs = self.language_model(
|
| 856 |
+
attention_mask=causal_mask_mapping,
|
| 857 |
+
position_ids=position_ids,
|
| 858 |
+
past_key_values=past_key_values,
|
| 859 |
+
inputs_embeds=inputs_embeds,
|
| 860 |
+
use_cache=use_cache,
|
| 861 |
+
return_dict=True,
|
| 862 |
+
**lm_kwargs,
|
| 863 |
+
)
|
| 864 |
+
|
| 865 |
+
return Gemma3ModelOutputWithPast(
|
| 866 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 867 |
+
past_key_values=outputs.past_key_values,
|
| 868 |
+
hidden_states=outputs.hidden_states,
|
| 869 |
+
attentions=outputs.attentions,
|
| 870 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
@auto_docstring(
|
| 875 |
+
custom_intro="""
|
| 876 |
+
The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
|
| 877 |
+
"""
|
| 878 |
+
)
|
| 879 |
+
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
| 880 |
+
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
| 881 |
+
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
| 882 |
+
# Fix: https://github.com/huggingface/transformers/issues/40564
|
| 883 |
+
accepts_loss_kwargs = False
|
| 884 |
+
|
| 885 |
+
def __init__(self, config: Gemma3Config):
|
| 886 |
+
super().__init__(config)
|
| 887 |
+
self.model = Gemma3Model(config)
|
| 888 |
+
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 889 |
+
self.post_init()
|
| 890 |
+
|
| 891 |
+
@auto_docstring
|
| 892 |
+
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
|
| 893 |
+
return self.model.get_image_features(pixel_values, **kwargs)
|
| 894 |
+
|
| 895 |
+
@can_return_tuple
|
| 896 |
+
@auto_docstring
|
| 897 |
+
def forward(
|
| 898 |
+
self,
|
| 899 |
+
input_ids: torch.LongTensor | None = None,
|
| 900 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 901 |
+
attention_mask: torch.Tensor | None = None,
|
| 902 |
+
position_ids: torch.LongTensor | None = None,
|
| 903 |
+
past_key_values: Cache | None = None,
|
| 904 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 905 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 906 |
+
labels: torch.LongTensor | None = None,
|
| 907 |
+
use_cache: bool | None = None,
|
| 908 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 909 |
+
**lm_kwargs: Unpack[TransformersKwargs],
|
| 910 |
+
) -> tuple | Gemma3CausalLMOutputWithPast:
|
| 911 |
+
r"""
|
| 912 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 913 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 914 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 915 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 916 |
+
|
| 917 |
+
Example:
|
| 918 |
+
|
| 919 |
+
```python
|
| 920 |
+
>>> from PIL import Image
|
| 921 |
+
>>> import httpx
|
| 922 |
+
>>> from io import BytesIO
|
| 923 |
+
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 924 |
+
|
| 925 |
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
| 926 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
| 927 |
+
|
| 928 |
+
>>> messages = [
|
| 929 |
+
... {
|
| 930 |
+
... "role": "system",
|
| 931 |
+
... "content": [
|
| 932 |
+
... {"type": "text", "text": "You are a helpful assistant."}
|
| 933 |
+
... ]
|
| 934 |
+
... },
|
| 935 |
+
... {
|
| 936 |
+
... "role": "user", "content": [
|
| 937 |
+
... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
| 938 |
+
... {"type": "text", "text": "Where is the cat standing?"},
|
| 939 |
+
... ]
|
| 940 |
+
... },
|
| 941 |
+
... ]
|
| 942 |
+
|
| 943 |
+
>>> inputs = processor.apply_chat_template(
|
| 944 |
+
... messages,
|
| 945 |
+
... tokenize=True,
|
| 946 |
+
... return_dict=True,
|
| 947 |
+
... return_tensors="pt",
|
| 948 |
+
... add_generation_prompt=True
|
| 949 |
+
... )
|
| 950 |
+
>>> # Generate
|
| 951 |
+
>>> generate_ids = model.generate(**inputs)
|
| 952 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 953 |
+
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
| 954 |
+
```
|
| 955 |
+
"""
|
| 956 |
+
outputs = self.model(
|
| 957 |
+
input_ids=input_ids,
|
| 958 |
+
pixel_values=pixel_values,
|
| 959 |
+
token_type_ids=token_type_ids,
|
| 960 |
+
attention_mask=attention_mask,
|
| 961 |
+
position_ids=position_ids,
|
| 962 |
+
past_key_values=past_key_values,
|
| 963 |
+
inputs_embeds=inputs_embeds,
|
| 964 |
+
use_cache=use_cache,
|
| 965 |
+
labels=labels,
|
| 966 |
+
return_dict=True,
|
| 967 |
+
**lm_kwargs,
|
| 968 |
+
)
|
| 969 |
+
|
| 970 |
+
hidden_states = outputs[0]
|
| 971 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 972 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 973 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 974 |
+
|
| 975 |
+
loss = None
|
| 976 |
+
if labels is not None:
|
| 977 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 978 |
+
logits = logits.float()
|
| 979 |
+
shift_logits = logits[..., :-1, :]
|
| 980 |
+
shift_labels = labels[..., 1:]
|
| 981 |
+
if attention_mask is not None:
|
| 982 |
+
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
| 983 |
+
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
| 984 |
+
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
|
| 985 |
+
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
| 986 |
+
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
| 987 |
+
else:
|
| 988 |
+
shift_logits = shift_logits.contiguous()
|
| 989 |
+
shift_labels = shift_labels.contiguous()
|
| 990 |
+
# Flatten the tokens
|
| 991 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 992 |
+
|
| 993 |
+
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 994 |
+
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
| 995 |
+
loss = loss_fct(flat_logits, flat_labels)
|
| 996 |
+
|
| 997 |
+
return Gemma3CausalLMOutputWithPast(
|
| 998 |
+
loss=loss,
|
| 999 |
+
logits=logits,
|
| 1000 |
+
past_key_values=outputs.past_key_values,
|
| 1001 |
+
hidden_states=outputs.hidden_states,
|
| 1002 |
+
attentions=outputs.attentions,
|
| 1003 |
+
image_hidden_states=outputs.image_hidden_states,
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
def prepare_inputs_for_generation(
|
| 1007 |
+
self,
|
| 1008 |
+
input_ids,
|
| 1009 |
+
past_key_values=None,
|
| 1010 |
+
inputs_embeds=None,
|
| 1011 |
+
position_ids=None,
|
| 1012 |
+
pixel_values=None,
|
| 1013 |
+
attention_mask=None,
|
| 1014 |
+
token_type_ids=None,
|
| 1015 |
+
use_cache=True,
|
| 1016 |
+
logits_to_keep=None,
|
| 1017 |
+
labels=None,
|
| 1018 |
+
is_first_iteration=False,
|
| 1019 |
+
**kwargs,
|
| 1020 |
+
):
|
| 1021 |
+
# Overwritten -- custom `pixel_values` handling
|
| 1022 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 1023 |
+
input_ids,
|
| 1024 |
+
past_key_values=past_key_values,
|
| 1025 |
+
inputs_embeds=inputs_embeds,
|
| 1026 |
+
attention_mask=attention_mask,
|
| 1027 |
+
position_ids=position_ids,
|
| 1028 |
+
use_cache=use_cache,
|
| 1029 |
+
logits_to_keep=logits_to_keep,
|
| 1030 |
+
token_type_ids=token_type_ids,
|
| 1031 |
+
is_first_iteration=is_first_iteration,
|
| 1032 |
+
**kwargs,
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
# Pixel values are used only in the first iteration if available
|
| 1036 |
+
# In subsequent iterations, they are already merged with text and cached
|
| 1037 |
+
# NOTE: first iteration doesn't have to be prefill, it can be the first
|
| 1038 |
+
# iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
|
| 1039 |
+
if is_first_iteration or not use_cache:
|
| 1040 |
+
model_inputs["pixel_values"] = pixel_values
|
| 1041 |
+
else:
|
| 1042 |
+
# Don't pass to not apply bidirectional mask on top
|
| 1043 |
+
model_inputs["token_type_ids"] = None
|
| 1044 |
+
|
| 1045 |
+
return model_inputs
|
| 1046 |
+
|
| 1047 |
+
@staticmethod
|
| 1048 |
+
def create_masks_for_generate(
|
| 1049 |
+
config: PreTrainedConfig,
|
| 1050 |
+
inputs_embeds: torch.Tensor,
|
| 1051 |
+
attention_mask: torch.Tensor | None,
|
| 1052 |
+
past_key_values: Cache | None,
|
| 1053 |
+
position_ids: torch.Tensor | None,
|
| 1054 |
+
token_type_ids: torch.Tensor | None = None,
|
| 1055 |
+
is_first_iteration: bool | None = False,
|
| 1056 |
+
**kwargs,
|
| 1057 |
+
) -> dict:
|
| 1058 |
+
mask_kwargs = {
|
| 1059 |
+
"config": config.get_text_config(),
|
| 1060 |
+
"inputs_embeds": inputs_embeds,
|
| 1061 |
+
"attention_mask": attention_mask,
|
| 1062 |
+
"past_key_values": past_key_values,
|
| 1063 |
+
"position_ids": position_ids,
|
| 1064 |
+
}
|
| 1065 |
+
|
| 1066 |
+
if token_type_ids is not None:
|
| 1067 |
+
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
|
| 1068 |
+
token_type_ids, device=inputs_embeds.device
|
| 1069 |
+
)
|
| 1070 |
+
|
| 1071 |
+
return create_masks_for_generate(**mask_kwargs)
|
| 1072 |
+
|
| 1073 |
+
|
| 1074 |
+
@auto_docstring(
|
| 1075 |
+
custom_intro="""
|
| 1076 |
+
Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
|
| 1077 |
+
It uses the generic sequence classification implementation for efficiency and consistency."""
|
| 1078 |
+
)
|
| 1079 |
+
class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
|
| 1080 |
+
config: Gemma3TextConfig
|
| 1081 |
+
input_modalities = ("text",)
|
| 1082 |
+
|
| 1083 |
+
|
| 1084 |
+
class Gemma3ForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
|
| 1085 |
+
def forward(
|
| 1086 |
+
self,
|
| 1087 |
+
input_ids: torch.LongTensor | None = None,
|
| 1088 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 1089 |
+
attention_mask: torch.Tensor | None = None,
|
| 1090 |
+
position_ids: torch.LongTensor | None = None,
|
| 1091 |
+
past_key_values: Cache | None = None,
|
| 1092 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 1093 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 1094 |
+
labels: torch.LongTensor | None = None,
|
| 1095 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1096 |
+
) -> SequenceClassifierOutputWithPast:
|
| 1097 |
+
return super().forward(
|
| 1098 |
+
input_ids=input_ids,
|
| 1099 |
+
attention_mask=attention_mask,
|
| 1100 |
+
position_ids=position_ids,
|
| 1101 |
+
past_key_values=past_key_values,
|
| 1102 |
+
inputs_embeds=inputs_embeds,
|
| 1103 |
+
pixel_values=pixel_values,
|
| 1104 |
+
token_type_ids=token_type_ids,
|
| 1105 |
+
labels=labels,
|
| 1106 |
+
**kwargs,
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
__all__ = [
|
| 1111 |
+
"Gemma3PreTrainedModel",
|
| 1112 |
+
"Gemma3TextModel",
|
| 1113 |
+
"Gemma3ForCausalLM",
|
| 1114 |
+
"Gemma3ForConditionalGeneration",
|
| 1115 |
+
"Gemma3Model",
|
| 1116 |
+
"Gemma3ForSequenceClassification",
|
| 1117 |
+
"Gemma3TextForSequenceClassification",
|
| 1118 |
+
]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modular_gemma3.py
ADDED
|
@@ -0,0 +1,941 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from collections.abc import Callable
|
| 16 |
+
from typing import Any, Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from huggingface_hub.dataclasses import strict
|
| 21 |
+
|
| 22 |
+
from ... import initialization as init
|
| 23 |
+
from ...cache_utils import Cache, DynamicCache
|
| 24 |
+
from ...configuration_utils import PreTrainedConfig
|
| 25 |
+
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
| 26 |
+
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
|
| 27 |
+
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, SequenceClassifierOutputWithPast
|
| 28 |
+
from ...modeling_rope_utils import (
|
| 29 |
+
ROPE_INIT_FUNCTIONS,
|
| 30 |
+
dynamic_rope_update,
|
| 31 |
+
)
|
| 32 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 33 |
+
from ...processing_utils import Unpack
|
| 34 |
+
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 35 |
+
from ...utils.generic import maybe_autocast
|
| 36 |
+
from ..gemma2.configuration_gemma2 import Gemma2Config
|
| 37 |
+
from ..gemma2.modeling_gemma2 import (
|
| 38 |
+
Gemma2Attention,
|
| 39 |
+
Gemma2ForCausalLM,
|
| 40 |
+
Gemma2MLP,
|
| 41 |
+
Gemma2Model,
|
| 42 |
+
Gemma2PreTrainedModel,
|
| 43 |
+
Gemma2RMSNorm,
|
| 44 |
+
Gemma2RotaryEmbedding,
|
| 45 |
+
apply_rotary_pos_emb,
|
| 46 |
+
eager_attention_forward,
|
| 47 |
+
)
|
| 48 |
+
from ..paligemma.modeling_paligemma import (
|
| 49 |
+
PaliGemmaCausalLMOutputWithPast,
|
| 50 |
+
PaliGemmaForConditionalGeneration,
|
| 51 |
+
PaliGemmaModel,
|
| 52 |
+
PaligemmaModelOutputWithPast,
|
| 53 |
+
)
|
| 54 |
+
from ..siglip import SiglipVisionConfig
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
logger = logging.get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@auto_docstring(checkpoint="google/gemma-3-4b-it")
|
| 61 |
+
@strict
|
| 62 |
+
class Gemma3TextConfig(Gemma2Config, PreTrainedConfig):
|
| 63 |
+
r"""
|
| 64 |
+
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
| 65 |
+
scaling factor used on the attention scores
|
| 66 |
+
final_logit_softcapping (`float`, *optional*):
|
| 67 |
+
Scaling factor when applying tanh softcapping on the logits.
|
| 68 |
+
attn_logit_softcapping (`float`, *optional*):
|
| 69 |
+
Scaling factor when applying tanh softcapping on the attention scores.
|
| 70 |
+
use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
|
| 71 |
+
If True, the model will attend to all text tokens instead of using a causal mask. This does not change
|
| 72 |
+
behavior for vision tokens.
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
|
| 76 |
+
>>> # Initializing a Gemma3Text gemma3_text-7b style configuration
|
| 77 |
+
>>> configuration = Gemma3TextConfig()
|
| 78 |
+
>>> # Initializing a model from the gemma3_text-7b style configuration
|
| 79 |
+
>>> model = Gemma3TextModel(configuration)
|
| 80 |
+
>>> # Accessing the model configuration
|
| 81 |
+
>>> configuration = model.config
|
| 82 |
+
```
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_type = "gemma3_text"
|
| 86 |
+
base_model_tp_plan = {
|
| 87 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 88 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 89 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 90 |
+
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
|
| 91 |
+
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
|
| 92 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 93 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 94 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 95 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 96 |
+
}
|
| 97 |
+
default_theta = {"global": 1_000_000.0, "local": 10_000.0}
|
| 98 |
+
|
| 99 |
+
vocab_size: int = 262_208
|
| 100 |
+
max_position_embeddings: int = 131_072
|
| 101 |
+
layer_types: list[str] | None = None
|
| 102 |
+
final_logit_softcapping: float | None = None
|
| 103 |
+
attn_logit_softcapping: float | None = None
|
| 104 |
+
rope_parameters: dict | None = None
|
| 105 |
+
use_bidirectional_attention: bool | None = False
|
| 106 |
+
|
| 107 |
+
def __post_init__(self, **kwargs):
|
| 108 |
+
if self.use_bidirectional_attention:
|
| 109 |
+
self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
|
| 110 |
+
|
| 111 |
+
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
|
| 112 |
+
self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
|
| 113 |
+
|
| 114 |
+
if self.layer_types is None:
|
| 115 |
+
self.layer_types = [
|
| 116 |
+
"sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
|
| 117 |
+
for i in range(self.num_hidden_layers)
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
PreTrainedConfig.__post_init__(**kwargs)
|
| 121 |
+
|
| 122 |
+
def convert_rope_params_to_dict(self, **kwargs):
|
| 123 |
+
rope_scaling = kwargs.pop("rope_scaling", None)
|
| 124 |
+
|
| 125 |
+
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
|
| 126 |
+
# as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
|
| 127 |
+
default_rope_params = {
|
| 128 |
+
"sliding_attention": {"rope_type": "default"},
|
| 129 |
+
"full_attention": {"rope_type": "default"},
|
| 130 |
+
}
|
| 131 |
+
self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else default_rope_params
|
| 132 |
+
if rope_scaling is not None:
|
| 133 |
+
self.rope_parameters["full_attention"].update(rope_scaling)
|
| 134 |
+
|
| 135 |
+
# Set default values if not present
|
| 136 |
+
if self.rope_parameters.get("full_attention") is None:
|
| 137 |
+
self.rope_parameters["full_attention"] = {"rope_type": "default"}
|
| 138 |
+
self.rope_parameters["full_attention"].setdefault(
|
| 139 |
+
"rope_theta", kwargs.pop("rope_theta", self.default_theta["global"])
|
| 140 |
+
)
|
| 141 |
+
if self.rope_parameters.get("sliding_attention") is None:
|
| 142 |
+
self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
|
| 143 |
+
self.rope_parameters["sliding_attention"].setdefault(
|
| 144 |
+
"rope_theta", kwargs.pop("rope_local_base_freq", self.default_theta["local"])
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Standardize and validate the correctness of rotary position embeddings parameters
|
| 148 |
+
self.standardize_rope_params()
|
| 149 |
+
return kwargs
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@auto_docstring(checkpoint="google/gemma-3-4b-it")
|
| 153 |
+
@strict
|
| 154 |
+
class Gemma3Config(PreTrainedConfig):
|
| 155 |
+
r"""
|
| 156 |
+
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
| 157 |
+
The number of tokens per image embedding.
|
| 158 |
+
boi_token_index (`int`, *optional*, defaults to 255999):
|
| 159 |
+
The begin-of-image token index to wrap the image prompt.
|
| 160 |
+
eoi_token_index (`int`, *optional*, defaults to 256000):
|
| 161 |
+
The end-of-image token index to wrap the image prompt.
|
| 162 |
+
|
| 163 |
+
Example:
|
| 164 |
+
|
| 165 |
+
```python
|
| 166 |
+
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
| 167 |
+
|
| 168 |
+
>>> # Initializing a Siglip-like vision config
|
| 169 |
+
>>> vision_config = SiglipVisionConfig()
|
| 170 |
+
|
| 171 |
+
>>> # Initializing a Gemma3 Text config
|
| 172 |
+
>>> text_config = Gemma3TextConfig()
|
| 173 |
+
|
| 174 |
+
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
| 175 |
+
>>> configuration = Gemma3Config(vision_config, text_config)
|
| 176 |
+
|
| 177 |
+
>>> # Initializing a model from the gemma-3-4b style configuration
|
| 178 |
+
>>> model = Gemma3TextConfig(configuration)
|
| 179 |
+
|
| 180 |
+
>>> # Accessing the model configuration
|
| 181 |
+
>>> configuration = model.config
|
| 182 |
+
```"""
|
| 183 |
+
|
| 184 |
+
model_type = "gemma3"
|
| 185 |
+
attribute_map = {
|
| 186 |
+
"image_token_id": "image_token_index",
|
| 187 |
+
"boi_token_id": "boi_token_index",
|
| 188 |
+
"eoi_token_id": "eoi_token_index",
|
| 189 |
+
}
|
| 190 |
+
sub_configs = {
|
| 191 |
+
"text_config": Gemma3TextConfig,
|
| 192 |
+
"vision_config": SiglipVisionConfig,
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
text_config: Gemma3TextConfig | dict[str, Any] | None = None
|
| 196 |
+
vision_config: SiglipVisionConfig | dict[str, Any] | None = None
|
| 197 |
+
mm_tokens_per_image: int | None = 256
|
| 198 |
+
boi_token_index: int | None = 255_999
|
| 199 |
+
eoi_token_index: int | None = 256_000
|
| 200 |
+
image_token_index: int | None = 262_144
|
| 201 |
+
initializer_range: float | None = 0.02
|
| 202 |
+
tie_word_embeddings: bool | None = True
|
| 203 |
+
|
| 204 |
+
def __post_init__(self, **kwargs):
|
| 205 |
+
if self.text_config is None:
|
| 206 |
+
self.text_config = Gemma3TextConfig()
|
| 207 |
+
logger.info("text_config is None, using default Gemma3TextConfig text config.")
|
| 208 |
+
elif isinstance(self.text_config, dict):
|
| 209 |
+
self.text_config = Gemma3TextConfig(**self.text_config)
|
| 210 |
+
|
| 211 |
+
if isinstance(self.vision_config, dict):
|
| 212 |
+
self.vision_config = SiglipVisionConfig(**self.vision_config)
|
| 213 |
+
elif self.vision_config is None:
|
| 214 |
+
self.vision_config = SiglipVisionConfig()
|
| 215 |
+
logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
|
| 216 |
+
|
| 217 |
+
super().__post_init__(**kwargs)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
|
| 221 |
+
pass
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class Gemma3CausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast):
|
| 225 |
+
pass
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class Gemma3TextScaledWordEmbedding(nn.Embedding):
|
| 229 |
+
"""
|
| 230 |
+
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
| 234 |
+
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
| 235 |
+
self.scalar_embed_scale = embed_scale
|
| 236 |
+
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
| 237 |
+
|
| 238 |
+
def forward(self, input_ids: torch.Tensor):
|
| 239 |
+
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class Gemma3MLP(Gemma2MLP):
|
| 243 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 244 |
+
super().__init__(config)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Gemma3RMSNorm(Gemma2RMSNorm):
|
| 248 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 249 |
+
super().__init__(dim=dim, eps=eps)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding, nn.Module):
|
| 253 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 254 |
+
nn.Module.__init__(self)
|
| 255 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 256 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 257 |
+
self.config = config
|
| 258 |
+
self.layer_types = list(set(config.layer_types))
|
| 259 |
+
self.rope_type = {}
|
| 260 |
+
for layer_type in self.layer_types:
|
| 261 |
+
rope_params = self.config.rope_parameters[layer_type]
|
| 262 |
+
if rope_params is None:
|
| 263 |
+
continue
|
| 264 |
+
|
| 265 |
+
self.rope_type[layer_type] = rope_params["rope_type"]
|
| 266 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 267 |
+
if self.rope_type[layer_type] != "default":
|
| 268 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
|
| 269 |
+
curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type)
|
| 270 |
+
self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
|
| 271 |
+
self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
|
| 272 |
+
setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
|
| 273 |
+
|
| 274 |
+
@staticmethod
|
| 275 |
+
def compute_default_rope_parameters(
|
| 276 |
+
config: Gemma3TextConfig | None = None,
|
| 277 |
+
device: Optional["torch.device"] = None,
|
| 278 |
+
seq_len: int | None = None,
|
| 279 |
+
layer_type: str | None = None,
|
| 280 |
+
) -> tuple["torch.Tensor", float]:
|
| 281 |
+
"""
|
| 282 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 283 |
+
Args:
|
| 284 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 285 |
+
The model configuration.
|
| 286 |
+
device (`torch.device`):
|
| 287 |
+
The device to use for initialization of the inverse frequencies.
|
| 288 |
+
seq_len (`int`, *optional*):
|
| 289 |
+
The current sequence length. Unused for this type of RoPE.
|
| 290 |
+
layer_type (`str`, *optional*):
|
| 291 |
+
The current layer type if the model has different RoPE parameters per type.
|
| 292 |
+
Should not be used unless `config.layer_types is not None`
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 296 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 297 |
+
"""
|
| 298 |
+
# For backward compatibility standardize the `rope_parameters_dict` if it uses old format
|
| 299 |
+
base = config.rope_parameters[layer_type]["rope_theta"]
|
| 300 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 301 |
+
|
| 302 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 303 |
+
|
| 304 |
+
# Compute the inverse frequencies
|
| 305 |
+
inv_freq = 1.0 / (
|
| 306 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 307 |
+
)
|
| 308 |
+
return inv_freq, attention_factor
|
| 309 |
+
|
| 310 |
+
@torch.no_grad()
|
| 311 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 312 |
+
def forward(self, x, position_ids, layer_type=None):
|
| 313 |
+
inv_freq = getattr(self, f"{layer_type}_inv_freq")
|
| 314 |
+
attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
|
| 315 |
+
|
| 316 |
+
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 317 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 318 |
+
|
| 319 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 320 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 321 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 322 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 323 |
+
cos = emb.cos() * attention_scaling
|
| 324 |
+
sin = emb.sin() * attention_scaling
|
| 325 |
+
|
| 326 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
|
| 330 |
+
class Gemma3Attention(Gemma2Attention):
|
| 331 |
+
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
| 332 |
+
super().__init__(config, layer_idx)
|
| 333 |
+
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
|
| 334 |
+
self.is_sliding = self.layer_type == "sliding_attention"
|
| 335 |
+
self.is_causal = not self.config.use_bidirectional_attention
|
| 336 |
+
|
| 337 |
+
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 338 |
+
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 339 |
+
|
| 340 |
+
def forward(
|
| 341 |
+
self,
|
| 342 |
+
hidden_states: torch.Tensor,
|
| 343 |
+
position_embeddings: torch.Tensor = None,
|
| 344 |
+
attention_mask: torch.Tensor | None = None,
|
| 345 |
+
past_key_values: Cache | None = None,
|
| 346 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 347 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
| 348 |
+
input_shape = hidden_states.shape[:-1]
|
| 349 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 350 |
+
|
| 351 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 352 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 353 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 354 |
+
|
| 355 |
+
query_states = self.q_norm(query_states)
|
| 356 |
+
key_states = self.k_norm(key_states)
|
| 357 |
+
|
| 358 |
+
cos, sin = position_embeddings
|
| 359 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 360 |
+
|
| 361 |
+
if past_key_values is not None:
|
| 362 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
| 363 |
+
|
| 364 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 365 |
+
self.config._attn_implementation, eager_attention_forward
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
attn_output, attn_weights = attention_interface(
|
| 369 |
+
self,
|
| 370 |
+
query_states,
|
| 371 |
+
key_states,
|
| 372 |
+
value_states,
|
| 373 |
+
attention_mask,
|
| 374 |
+
dropout=self.attention_dropout if self.training else 0.0,
|
| 375 |
+
scaling=self.scaling,
|
| 376 |
+
sliding_window=self.sliding_window,
|
| 377 |
+
**kwargs,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 381 |
+
attn_output = self.o_proj(attn_output)
|
| 382 |
+
return attn_output, attn_weights
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
| 386 |
+
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
| 387 |
+
super().__init__()
|
| 388 |
+
self.config = config
|
| 389 |
+
self.hidden_size = config.hidden_size
|
| 390 |
+
self.layer_idx = layer_idx
|
| 391 |
+
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
|
| 392 |
+
self.mlp = Gemma3MLP(config)
|
| 393 |
+
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 394 |
+
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 395 |
+
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 396 |
+
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 397 |
+
|
| 398 |
+
def forward(
|
| 399 |
+
self,
|
| 400 |
+
hidden_states: torch.Tensor,
|
| 401 |
+
position_embeddings: torch.Tensor = None,
|
| 402 |
+
attention_mask: torch.Tensor | None = None,
|
| 403 |
+
position_ids: torch.LongTensor | None = None,
|
| 404 |
+
past_key_values: Cache | None = None,
|
| 405 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 406 |
+
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
| 407 |
+
residual = hidden_states
|
| 408 |
+
|
| 409 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 410 |
+
|
| 411 |
+
hidden_states, _ = self.self_attn(
|
| 412 |
+
hidden_states=hidden_states,
|
| 413 |
+
position_embeddings=position_embeddings,
|
| 414 |
+
attention_mask=attention_mask,
|
| 415 |
+
position_ids=position_ids,
|
| 416 |
+
past_key_values=past_key_values,
|
| 417 |
+
**kwargs,
|
| 418 |
+
)
|
| 419 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 420 |
+
hidden_states = residual + hidden_states
|
| 421 |
+
|
| 422 |
+
residual = hidden_states
|
| 423 |
+
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
| 424 |
+
hidden_states = self.mlp(hidden_states)
|
| 425 |
+
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
| 426 |
+
hidden_states = residual + hidden_states
|
| 427 |
+
|
| 428 |
+
return hidden_states
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
GEMMA3_START_DOCSTRING = None
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
|
| 435 |
+
base_model_prefix = "model"
|
| 436 |
+
input_modalities = ("image", "text")
|
| 437 |
+
_no_split_modules = [
|
| 438 |
+
"Gemma3DecoderLayer",
|
| 439 |
+
"SiglipVisionEmbeddings",
|
| 440 |
+
"SiglipEncoderLayer",
|
| 441 |
+
"SiglipMultiheadAttentionPoolingHead",
|
| 442 |
+
]
|
| 443 |
+
|
| 444 |
+
@torch.no_grad()
|
| 445 |
+
def _init_weights(self, module):
|
| 446 |
+
PreTrainedModel._init_weights(self, module)
|
| 447 |
+
if isinstance(module, Gemma3MultiModalProjector):
|
| 448 |
+
init.zeros_(module.mm_input_projection_weight)
|
| 449 |
+
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
| 450 |
+
elif "RMSNorm" in module.__class__.__name__:
|
| 451 |
+
init.zeros_(module.weight)
|
| 452 |
+
elif isinstance(module, Gemma3TextScaledWordEmbedding):
|
| 453 |
+
init.constant_(module.embed_scale, module.scalar_embed_scale)
|
| 454 |
+
elif isinstance(module, Gemma3RotaryEmbedding):
|
| 455 |
+
for layer_type in module.layer_types:
|
| 456 |
+
rope_init_fn = module.compute_default_rope_parameters
|
| 457 |
+
if module.rope_type[layer_type] != "default":
|
| 458 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
|
| 459 |
+
curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
|
| 460 |
+
init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
|
| 461 |
+
init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
|
| 465 |
+
"""
|
| 466 |
+
Enables a bidirectional mask within the sliding window.
|
| 467 |
+
"""
|
| 468 |
+
|
| 469 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 470 |
+
"""A token can attend to any other token if their absolute distance is within
|
| 471 |
+
the (exclusive) sliding window size (distance < sliding_window)."""
|
| 472 |
+
return abs(q_idx - kv_idx) < sliding_window
|
| 473 |
+
|
| 474 |
+
return inner_mask
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class Gemma3TextModel(Gemma2Model):
|
| 478 |
+
config: Gemma3TextConfig
|
| 479 |
+
input_modalities = ("text",)
|
| 480 |
+
|
| 481 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 482 |
+
super().__init__(config)
|
| 483 |
+
|
| 484 |
+
# Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
|
| 485 |
+
self.embed_tokens = Gemma3TextScaledWordEmbedding(
|
| 486 |
+
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
def forward(
|
| 490 |
+
self,
|
| 491 |
+
input_ids: torch.LongTensor | None = None,
|
| 492 |
+
attention_mask: torch.Tensor | None = None,
|
| 493 |
+
position_ids: torch.LongTensor | None = None,
|
| 494 |
+
past_key_values: Cache | None = None,
|
| 495 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 496 |
+
use_cache: bool | None = None,
|
| 497 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 498 |
+
) -> BaseModelOutputWithPast:
|
| 499 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 500 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 501 |
+
|
| 502 |
+
if inputs_embeds is None:
|
| 503 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 504 |
+
|
| 505 |
+
if use_cache and past_key_values is None:
|
| 506 |
+
past_key_values = DynamicCache(config=self.config)
|
| 507 |
+
|
| 508 |
+
if position_ids is None:
|
| 509 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 510 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 511 |
+
position_ids = position_ids.unsqueeze(0)
|
| 512 |
+
|
| 513 |
+
# It may already have been prepared by e.g. `generate`
|
| 514 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 515 |
+
# Prepare mask arguments
|
| 516 |
+
mask_kwargs = {
|
| 517 |
+
"config": self.config,
|
| 518 |
+
"inputs_embeds": inputs_embeds,
|
| 519 |
+
"attention_mask": attention_mask,
|
| 520 |
+
"past_key_values": past_key_values,
|
| 521 |
+
"position_ids": position_ids,
|
| 522 |
+
}
|
| 523 |
+
sliding_mask_kwargs = mask_kwargs.copy()
|
| 524 |
+
|
| 525 |
+
if self.config.use_bidirectional_attention:
|
| 526 |
+
mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
|
| 527 |
+
sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
|
| 528 |
+
|
| 529 |
+
# Create the masks
|
| 530 |
+
causal_mask_mapping = {
|
| 531 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 532 |
+
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
# embed positions
|
| 536 |
+
hidden_states = inputs_embeds
|
| 537 |
+
position_embeddings = {}
|
| 538 |
+
for layer_type in set(self.config.layer_types):
|
| 539 |
+
position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
|
| 540 |
+
|
| 541 |
+
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
|
| 542 |
+
hidden_states = decoder_layer(
|
| 543 |
+
hidden_states,
|
| 544 |
+
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
|
| 545 |
+
position_embeddings=position_embeddings[self.config.layer_types[i]],
|
| 546 |
+
position_ids=position_ids,
|
| 547 |
+
past_key_values=past_key_values,
|
| 548 |
+
**kwargs,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
hidden_states = self.norm(hidden_states)
|
| 552 |
+
|
| 553 |
+
return BaseModelOutputWithPast(
|
| 554 |
+
last_hidden_state=hidden_states,
|
| 555 |
+
past_key_values=past_key_values,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class Gemma3ForCausalLM(Gemma2ForCausalLM):
|
| 560 |
+
config: Gemma3TextConfig
|
| 561 |
+
|
| 562 |
+
def __init__(self, config: Gemma3TextConfig):
|
| 563 |
+
super().__init__(config)
|
| 564 |
+
self.model = Gemma3TextModel(config)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class Gemma3MultiModalProjector(nn.Module):
|
| 568 |
+
def __init__(self, config: Gemma3Config):
|
| 569 |
+
super().__init__()
|
| 570 |
+
|
| 571 |
+
self.mm_input_projection_weight = nn.Parameter(
|
| 572 |
+
torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
self.mm_soft_emb_norm = Gemma3RMSNorm(
|
| 576 |
+
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
|
| 580 |
+
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
| 581 |
+
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
| 582 |
+
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
|
| 583 |
+
|
| 584 |
+
def forward(self, vision_outputs: torch.Tensor):
|
| 585 |
+
batch_size, _, hidden_size = vision_outputs.shape
|
| 586 |
+
|
| 587 |
+
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
| 588 |
+
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
| 589 |
+
batch_size, hidden_size, self.patches_per_image, self.patches_per_image
|
| 590 |
+
)
|
| 591 |
+
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
| 592 |
+
|
| 593 |
+
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
| 594 |
+
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
| 595 |
+
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
| 596 |
+
|
| 597 |
+
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
|
| 598 |
+
|
| 599 |
+
projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
|
| 600 |
+
return projected_vision_outputs.type_as(vision_outputs)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch.device | None = None) -> torch.Tensor:
|
| 604 |
+
# First find where a new image block starts: 1 if image and previous not image
|
| 605 |
+
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
| 606 |
+
is_image = (token_type_ids == 1).to(device=device)
|
| 607 |
+
is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
| 608 |
+
new_image_start = is_image & ~is_previous_image
|
| 609 |
+
group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
| 610 |
+
block_sequence_ids = torch.where(is_image, group_ids, -1)
|
| 611 |
+
return block_sequence_ids
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class Gemma3Model(PaliGemmaModel):
|
| 615 |
+
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
| 616 |
+
accepts_loss_kwargs = False
|
| 617 |
+
|
| 618 |
+
def __init__(self, config: Gemma3Config):
|
| 619 |
+
super().__init__(config)
|
| 620 |
+
del self.text_config_dtype
|
| 621 |
+
|
| 622 |
+
@can_return_tuple
|
| 623 |
+
@auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
|
| 624 |
+
def get_image_features(
|
| 625 |
+
self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
|
| 626 |
+
) -> tuple | BaseModelOutputWithPooling:
|
| 627 |
+
vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
|
| 628 |
+
last_hidden_state = vision_outputs.last_hidden_state
|
| 629 |
+
vision_outputs.pooler_output = self.multi_modal_projector(last_hidden_state)
|
| 630 |
+
|
| 631 |
+
return vision_outputs
|
| 632 |
+
|
| 633 |
+
@can_return_tuple
|
| 634 |
+
@auto_docstring
|
| 635 |
+
def forward(
|
| 636 |
+
self,
|
| 637 |
+
input_ids: torch.LongTensor | None = None,
|
| 638 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 639 |
+
attention_mask: torch.Tensor | None = None,
|
| 640 |
+
position_ids: torch.LongTensor | None = None,
|
| 641 |
+
past_key_values: Cache | None = None,
|
| 642 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 643 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 644 |
+
labels: torch.LongTensor | None = None,
|
| 645 |
+
use_cache: bool | None = None,
|
| 646 |
+
**lm_kwargs: Unpack[TransformersKwargs],
|
| 647 |
+
) -> tuple | Gemma3ModelOutputWithPast:
|
| 648 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 649 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 650 |
+
|
| 651 |
+
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
| 652 |
+
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
| 653 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 654 |
+
llm_input_ids = input_ids.clone()
|
| 655 |
+
llm_input_ids[special_image_mask] = 0
|
| 656 |
+
else:
|
| 657 |
+
llm_input_ids = input_ids
|
| 658 |
+
|
| 659 |
+
if inputs_embeds is None:
|
| 660 |
+
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
| 661 |
+
|
| 662 |
+
# Merge text and images
|
| 663 |
+
if pixel_values is not None:
|
| 664 |
+
image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
|
| 665 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 666 |
+
special_image_mask = self.get_placeholder_mask(
|
| 667 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
| 668 |
+
)
|
| 669 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 670 |
+
|
| 671 |
+
# It may already have been prepared by e.g. `generate`
|
| 672 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 673 |
+
mask_kwargs = {
|
| 674 |
+
"config": self.config.get_text_config(),
|
| 675 |
+
"inputs_embeds": inputs_embeds,
|
| 676 |
+
"attention_mask": attention_mask,
|
| 677 |
+
"past_key_values": past_key_values,
|
| 678 |
+
"position_ids": position_ids,
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
if token_type_ids is not None:
|
| 682 |
+
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
|
| 683 |
+
token_type_ids, device=inputs_embeds.device
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
# Create the masks
|
| 687 |
+
sliding_mask_kwargs = mask_kwargs.copy()
|
| 688 |
+
causal_mask_mapping = {
|
| 689 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 690 |
+
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
|
| 691 |
+
}
|
| 692 |
+
|
| 693 |
+
outputs = self.language_model(
|
| 694 |
+
attention_mask=causal_mask_mapping,
|
| 695 |
+
position_ids=position_ids,
|
| 696 |
+
past_key_values=past_key_values,
|
| 697 |
+
inputs_embeds=inputs_embeds,
|
| 698 |
+
use_cache=use_cache,
|
| 699 |
+
return_dict=True,
|
| 700 |
+
**lm_kwargs,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
return Gemma3ModelOutputWithPast(
|
| 704 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 705 |
+
past_key_values=outputs.past_key_values,
|
| 706 |
+
hidden_states=outputs.hidden_states,
|
| 707 |
+
attentions=outputs.attentions,
|
| 708 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
| 713 |
+
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
| 714 |
+
# Fix: https://github.com/huggingface/transformers/issues/40564
|
| 715 |
+
accepts_loss_kwargs = False
|
| 716 |
+
|
| 717 |
+
@can_return_tuple
|
| 718 |
+
@auto_docstring
|
| 719 |
+
def forward(
|
| 720 |
+
self,
|
| 721 |
+
input_ids: torch.LongTensor | None = None,
|
| 722 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 723 |
+
attention_mask: torch.Tensor | None = None,
|
| 724 |
+
position_ids: torch.LongTensor | None = None,
|
| 725 |
+
past_key_values: Cache | None = None,
|
| 726 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 727 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 728 |
+
labels: torch.LongTensor | None = None,
|
| 729 |
+
use_cache: bool | None = None,
|
| 730 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 731 |
+
**lm_kwargs: Unpack[TransformersKwargs],
|
| 732 |
+
) -> tuple | Gemma3CausalLMOutputWithPast:
|
| 733 |
+
r"""
|
| 734 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 735 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 736 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 737 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 738 |
+
|
| 739 |
+
Example:
|
| 740 |
+
|
| 741 |
+
```python
|
| 742 |
+
>>> from PIL import Image
|
| 743 |
+
>>> import httpx
|
| 744 |
+
>>> from io import BytesIO
|
| 745 |
+
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 746 |
+
|
| 747 |
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
| 748 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
| 749 |
+
|
| 750 |
+
>>> messages = [
|
| 751 |
+
... {
|
| 752 |
+
... "role": "system",
|
| 753 |
+
... "content": [
|
| 754 |
+
... {"type": "text", "text": "You are a helpful assistant."}
|
| 755 |
+
... ]
|
| 756 |
+
... },
|
| 757 |
+
... {
|
| 758 |
+
... "role": "user", "content": [
|
| 759 |
+
... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
| 760 |
+
... {"type": "text", "text": "Where is the cat standing?"},
|
| 761 |
+
... ]
|
| 762 |
+
... },
|
| 763 |
+
... ]
|
| 764 |
+
|
| 765 |
+
>>> inputs = processor.apply_chat_template(
|
| 766 |
+
... messages,
|
| 767 |
+
... tokenize=True,
|
| 768 |
+
... return_dict=True,
|
| 769 |
+
... return_tensors="pt",
|
| 770 |
+
... add_generation_prompt=True
|
| 771 |
+
... )
|
| 772 |
+
>>> # Generate
|
| 773 |
+
>>> generate_ids = model.generate(**inputs)
|
| 774 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 775 |
+
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
| 776 |
+
```
|
| 777 |
+
"""
|
| 778 |
+
outputs = self.model(
|
| 779 |
+
input_ids=input_ids,
|
| 780 |
+
pixel_values=pixel_values,
|
| 781 |
+
token_type_ids=token_type_ids,
|
| 782 |
+
attention_mask=attention_mask,
|
| 783 |
+
position_ids=position_ids,
|
| 784 |
+
past_key_values=past_key_values,
|
| 785 |
+
inputs_embeds=inputs_embeds,
|
| 786 |
+
use_cache=use_cache,
|
| 787 |
+
labels=labels,
|
| 788 |
+
return_dict=True,
|
| 789 |
+
**lm_kwargs,
|
| 790 |
+
)
|
| 791 |
+
|
| 792 |
+
hidden_states = outputs[0]
|
| 793 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 794 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 795 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 796 |
+
|
| 797 |
+
loss = None
|
| 798 |
+
if labels is not None:
|
| 799 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 800 |
+
logits = logits.float()
|
| 801 |
+
shift_logits = logits[..., :-1, :]
|
| 802 |
+
shift_labels = labels[..., 1:]
|
| 803 |
+
if attention_mask is not None:
|
| 804 |
+
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
| 805 |
+
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
| 806 |
+
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
|
| 807 |
+
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
| 808 |
+
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
| 809 |
+
else:
|
| 810 |
+
shift_logits = shift_logits.contiguous()
|
| 811 |
+
shift_labels = shift_labels.contiguous()
|
| 812 |
+
# Flatten the tokens
|
| 813 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 814 |
+
|
| 815 |
+
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 816 |
+
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
| 817 |
+
loss = loss_fct(flat_logits, flat_labels)
|
| 818 |
+
|
| 819 |
+
return Gemma3CausalLMOutputWithPast(
|
| 820 |
+
loss=loss,
|
| 821 |
+
logits=logits,
|
| 822 |
+
past_key_values=outputs.past_key_values,
|
| 823 |
+
hidden_states=outputs.hidden_states,
|
| 824 |
+
attentions=outputs.attentions,
|
| 825 |
+
image_hidden_states=outputs.image_hidden_states,
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
def prepare_inputs_for_generation(
|
| 829 |
+
self,
|
| 830 |
+
input_ids,
|
| 831 |
+
past_key_values=None,
|
| 832 |
+
inputs_embeds=None,
|
| 833 |
+
position_ids=None,
|
| 834 |
+
pixel_values=None,
|
| 835 |
+
attention_mask=None,
|
| 836 |
+
token_type_ids=None,
|
| 837 |
+
use_cache=True,
|
| 838 |
+
logits_to_keep=None,
|
| 839 |
+
labels=None,
|
| 840 |
+
is_first_iteration=False,
|
| 841 |
+
**kwargs,
|
| 842 |
+
):
|
| 843 |
+
# Overwritten -- custom `pixel_values` handling
|
| 844 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 845 |
+
input_ids,
|
| 846 |
+
past_key_values=past_key_values,
|
| 847 |
+
inputs_embeds=inputs_embeds,
|
| 848 |
+
attention_mask=attention_mask,
|
| 849 |
+
position_ids=position_ids,
|
| 850 |
+
use_cache=use_cache,
|
| 851 |
+
logits_to_keep=logits_to_keep,
|
| 852 |
+
token_type_ids=token_type_ids,
|
| 853 |
+
is_first_iteration=is_first_iteration,
|
| 854 |
+
**kwargs,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
# Pixel values are used only in the first iteration if available
|
| 858 |
+
# In subsequent iterations, they are already merged with text and cached
|
| 859 |
+
# NOTE: first iteration doesn't have to be prefill, it can be the first
|
| 860 |
+
# iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
|
| 861 |
+
if is_first_iteration or not use_cache:
|
| 862 |
+
model_inputs["pixel_values"] = pixel_values
|
| 863 |
+
else:
|
| 864 |
+
# Don't pass to not apply bidirectional mask on top
|
| 865 |
+
model_inputs["token_type_ids"] = None
|
| 866 |
+
|
| 867 |
+
return model_inputs
|
| 868 |
+
|
| 869 |
+
def create_masks_for_generate(
|
| 870 |
+
config: PreTrainedConfig,
|
| 871 |
+
inputs_embeds: torch.Tensor,
|
| 872 |
+
attention_mask: torch.Tensor | None,
|
| 873 |
+
past_key_values: Cache | None,
|
| 874 |
+
position_ids: torch.Tensor | None,
|
| 875 |
+
token_type_ids: torch.Tensor | None = None,
|
| 876 |
+
is_first_iteration: bool | None = False,
|
| 877 |
+
**kwargs,
|
| 878 |
+
) -> dict:
|
| 879 |
+
mask_kwargs = {
|
| 880 |
+
"config": config.get_text_config(),
|
| 881 |
+
"inputs_embeds": inputs_embeds,
|
| 882 |
+
"attention_mask": attention_mask,
|
| 883 |
+
"past_key_values": past_key_values,
|
| 884 |
+
"position_ids": position_ids,
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
if token_type_ids is not None:
|
| 888 |
+
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
|
| 889 |
+
token_type_ids, device=inputs_embeds.device
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
return create_masks_for_generate(**mask_kwargs)
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
@auto_docstring(
|
| 896 |
+
custom_intro="""
|
| 897 |
+
Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
|
| 898 |
+
It uses the generic sequence classification implementation for efficiency and consistency."""
|
| 899 |
+
)
|
| 900 |
+
class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
|
| 901 |
+
config: Gemma3TextConfig
|
| 902 |
+
input_modalities = ("text",)
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class Gemma3ForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
|
| 906 |
+
def forward(
|
| 907 |
+
self,
|
| 908 |
+
input_ids: torch.LongTensor | None = None,
|
| 909 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 910 |
+
attention_mask: torch.Tensor | None = None,
|
| 911 |
+
position_ids: torch.LongTensor | None = None,
|
| 912 |
+
past_key_values: Cache | None = None,
|
| 913 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 914 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 915 |
+
labels: torch.LongTensor | None = None,
|
| 916 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 917 |
+
) -> SequenceClassifierOutputWithPast:
|
| 918 |
+
return super().forward(
|
| 919 |
+
input_ids=input_ids,
|
| 920 |
+
attention_mask=attention_mask,
|
| 921 |
+
position_ids=position_ids,
|
| 922 |
+
past_key_values=past_key_values,
|
| 923 |
+
inputs_embeds=inputs_embeds,
|
| 924 |
+
pixel_values=pixel_values,
|
| 925 |
+
token_type_ids=token_type_ids,
|
| 926 |
+
labels=labels,
|
| 927 |
+
**kwargs,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
__all__ = [
|
| 932 |
+
"Gemma3Config",
|
| 933 |
+
"Gemma3TextConfig",
|
| 934 |
+
"Gemma3PreTrainedModel",
|
| 935 |
+
"Gemma3TextModel",
|
| 936 |
+
"Gemma3ForCausalLM",
|
| 937 |
+
"Gemma3ForConditionalGeneration",
|
| 938 |
+
"Gemma3Model",
|
| 939 |
+
"Gemma3ForSequenceClassification",
|
| 940 |
+
"Gemma3TextForSequenceClassification",
|
| 941 |
+
]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/processing_gemma3.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import re
|
| 16 |
+
|
| 17 |
+
from ...feature_extraction_utils import BatchFeature
|
| 18 |
+
from ...image_utils import ImageInput, make_nested_list_of_images
|
| 19 |
+
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
| 20 |
+
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
| 21 |
+
from ...utils import auto_docstring, to_py_obj
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
| 25 |
+
_defaults = {
|
| 26 |
+
"text_kwargs": {
|
| 27 |
+
"padding": False,
|
| 28 |
+
"return_mm_token_type_ids": True,
|
| 29 |
+
},
|
| 30 |
+
"images_kwargs": {
|
| 31 |
+
"do_convert_rgb": True,
|
| 32 |
+
"do_pan_and_scan": False,
|
| 33 |
+
"pan_and_scan_min_crop_size": 256,
|
| 34 |
+
"pan_and_scan_max_num_crops": 4,
|
| 35 |
+
"pan_and_scan_min_ratio_to_activate": 1.2,
|
| 36 |
+
},
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@auto_docstring
|
| 41 |
+
class Gemma3Processor(ProcessorMixin):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
image_processor,
|
| 45 |
+
tokenizer,
|
| 46 |
+
chat_template=None,
|
| 47 |
+
image_seq_length: int = 256,
|
| 48 |
+
**kwargs,
|
| 49 |
+
):
|
| 50 |
+
self.image_seq_length = image_seq_length
|
| 51 |
+
self.image_token_id = tokenizer.image_token_id
|
| 52 |
+
self.boi_token = tokenizer.boi_token
|
| 53 |
+
self.image_token = tokenizer.image_token
|
| 54 |
+
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
|
| 55 |
+
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
| 56 |
+
|
| 57 |
+
super().__init__(
|
| 58 |
+
image_processor=image_processor,
|
| 59 |
+
tokenizer=tokenizer,
|
| 60 |
+
chat_template=chat_template,
|
| 61 |
+
**kwargs,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
@auto_docstring
|
| 65 |
+
def __call__(
|
| 66 |
+
self,
|
| 67 |
+
images: ImageInput | None = None,
|
| 68 |
+
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
| 69 |
+
**kwargs: Unpack[Gemma3ProcessorKwargs],
|
| 70 |
+
) -> BatchFeature:
|
| 71 |
+
if text is None and images is None:
|
| 72 |
+
raise ValueError("Provide at least one of `text` or `images`.")
|
| 73 |
+
|
| 74 |
+
output_kwargs = self._merge_kwargs(
|
| 75 |
+
Gemma3ProcessorKwargs,
|
| 76 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 77 |
+
**kwargs,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
if isinstance(text, str):
|
| 81 |
+
text = [text]
|
| 82 |
+
elif not isinstance(text, list) and not isinstance(text[0], str):
|
| 83 |
+
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
|
| 84 |
+
|
| 85 |
+
image_inputs = {}
|
| 86 |
+
if images is not None:
|
| 87 |
+
images = self.image_processor.fetch_images(images)
|
| 88 |
+
batched_images = make_nested_list_of_images(images)
|
| 89 |
+
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 90 |
+
|
| 91 |
+
# Create empty text to be replaced with placeholders
|
| 92 |
+
if not text:
|
| 93 |
+
text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
|
| 94 |
+
|
| 95 |
+
if len(batched_images) != len(text):
|
| 96 |
+
raise ValueError(
|
| 97 |
+
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Replace image tokens by the full expanded sequence
|
| 101 |
+
num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
| 102 |
+
batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
|
| 103 |
+
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
| 104 |
+
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
| 105 |
+
|
| 106 |
+
if len(images) != len(image_indexes):
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Insert additional image tokens for Pan-and-Scan crops
|
| 112 |
+
for num, idx in reversed(list(zip(num_crops, image_indexes))):
|
| 113 |
+
if num:
|
| 114 |
+
formatted_image_text = (
|
| 115 |
+
f"Here is the original image {self.boi_token} and here are some crops to help you see better "
|
| 116 |
+
+ " ".join([self.boi_token] * num)
|
| 117 |
+
)
|
| 118 |
+
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
|
| 119 |
+
text[batch_idx] = prompt
|
| 120 |
+
|
| 121 |
+
# Expand placeholder image tokens to the full image token sequence
|
| 122 |
+
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
| 123 |
+
|
| 124 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 125 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 126 |
+
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
| 127 |
+
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
| 128 |
+
|
| 129 |
+
if return_mm_token_type_ids:
|
| 130 |
+
text_inputs["token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
|
| 131 |
+
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
| 132 |
+
|
| 133 |
+
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
| 134 |
+
"""
|
| 135 |
+
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
image_sizes (`list[list[int]]`, *optional*):
|
| 139 |
+
The input sizes formatted as (height, width) per each image.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
| 143 |
+
input modalities, along with other useful data.
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
vision_data = {}
|
| 147 |
+
if image_sizes is not None:
|
| 148 |
+
# NOTE: no image cropping supported yet
|
| 149 |
+
num_image_tokens = [self.image_seq_length] * len(image_sizes)
|
| 150 |
+
num_image_patches = [1] * len(image_sizes)
|
| 151 |
+
|
| 152 |
+
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
| 153 |
+
|
| 154 |
+
return MultiModalData(**vision_data)
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def model_input_names(self):
|
| 158 |
+
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
| 159 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 160 |
+
|
| 161 |
+
image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
|
| 162 |
+
return list(tokenizer_input_names + image_processor_input_names)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
__all__ = ["Gemma3Processor"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_youtu import *
|
| 22 |
+
from .modeling_youtu import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/configuration_youtu.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/youtu/modular_youtu.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_youtu.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
from huggingface_hub.dataclasses import strict
|
| 28 |
+
|
| 29 |
+
from ...configuration_utils import PreTrainedConfig
|
| 30 |
+
from ...modeling_rope_utils import RopeParameters
|
| 31 |
+
from ...utils import auto_docstring
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@auto_docstring(checkpoint="tencent/Youtu-LLM-2B")
|
| 35 |
+
@strict
|
| 36 |
+
class YoutuConfig(PreTrainedConfig):
|
| 37 |
+
r"""
|
| 38 |
+
rope_interleave (`bool`, *optional*, defaults to `True`):
|
| 39 |
+
Whether to interleave the rotary position embeddings.
|
| 40 |
+
embedding_initializer_range (`float`, *optional*):
|
| 41 |
+
The standard deviation of the truncated_normal_initializer for initializing all embedding matrices.
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
>>> from transformers import YoutuModel, YoutuConfig
|
| 45 |
+
>>> # Initializing a Youtu-LLM-2B style configuration
|
| 46 |
+
>>> configuration = YoutuConfig()
|
| 47 |
+
>>> # Accessing the model configuration
|
| 48 |
+
>>> configuration = model.config
|
| 49 |
+
```"""
|
| 50 |
+
|
| 51 |
+
model_type = "youtu"
|
| 52 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 53 |
+
base_model_tp_plan = {
|
| 54 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 55 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 56 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 57 |
+
}
|
| 58 |
+
base_model_pp_plan = {
|
| 59 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 60 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 61 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 62 |
+
}
|
| 63 |
+
attribute_map = {}
|
| 64 |
+
|
| 65 |
+
vocab_size: int = 128256
|
| 66 |
+
hidden_size: int = 2048
|
| 67 |
+
intermediate_size: int = 6144
|
| 68 |
+
num_hidden_layers: int = 32
|
| 69 |
+
num_attention_heads: int = 16
|
| 70 |
+
num_key_value_heads: int = 16
|
| 71 |
+
kv_lora_rank: int = 512
|
| 72 |
+
q_lora_rank: int | None = 1536
|
| 73 |
+
qk_rope_head_dim: int = 64
|
| 74 |
+
v_head_dim: int | None = 128
|
| 75 |
+
qk_nope_head_dim: int = 128
|
| 76 |
+
hidden_act: str = "silu"
|
| 77 |
+
max_position_embeddings: int = 131072
|
| 78 |
+
initializer_range: float | None = None
|
| 79 |
+
rms_norm_eps: float = 1e-6
|
| 80 |
+
use_cache: bool = True
|
| 81 |
+
pad_token_id: int | None = None
|
| 82 |
+
bos_token_id: int | None = 128000
|
| 83 |
+
eos_token_id: int | list[int] | None = 128001
|
| 84 |
+
tie_word_embeddings: bool = True
|
| 85 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 86 |
+
rope_interleave: bool | None = True
|
| 87 |
+
attention_bias: bool = False
|
| 88 |
+
attention_dropout: float | int | None = 0.0
|
| 89 |
+
embedding_initializer_range: float | None = None
|
| 90 |
+
|
| 91 |
+
def __post_init__(self, **kwargs):
|
| 92 |
+
if self.initializer_range is None:
|
| 93 |
+
if self.hidden_size != 0:
|
| 94 |
+
self.initializer_range = 2.0 / (5.0 * self.hidden_size) ** 0.5
|
| 95 |
+
else:
|
| 96 |
+
self.initializer_range = 0.02
|
| 97 |
+
|
| 98 |
+
self.embedding_initializer_range = self.embedding_initializer_range or 2.0 * self.initializer_range
|
| 99 |
+
if self.num_key_value_heads is None:
|
| 100 |
+
self.num_key_value_heads = self.num_attention_heads
|
| 101 |
+
|
| 102 |
+
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
|
| 103 |
+
self.head_dim = self.qk_rope_head_dim
|
| 104 |
+
super().__post_init__(**kwargs)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
__all__ = ["YoutuConfig"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modeling_youtu.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/youtu/modular_youtu.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_youtu.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
import math
|
| 28 |
+
from collections.abc import Callable
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
from torch import nn
|
| 34 |
+
|
| 35 |
+
from ... import initialization as init
|
| 36 |
+
from ...activations import ACT2FN
|
| 37 |
+
from ...cache_utils import Cache, DynamicCache
|
| 38 |
+
from ...generation import GenerationMixin
|
| 39 |
+
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
|
| 40 |
+
from ...masking_utils import create_causal_mask
|
| 41 |
+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
| 42 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 43 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 44 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 45 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 46 |
+
from ...processing_utils import Unpack
|
| 47 |
+
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 48 |
+
from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
|
| 49 |
+
from ...utils.output_capturing import capture_outputs
|
| 50 |
+
from .configuration_youtu import YoutuConfig
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 54 |
+
class YoutuRMSNorm(nn.Module):
|
| 55 |
+
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
| 56 |
+
"""
|
| 57 |
+
YoutuRMSNorm is equivalent to T5LayerNorm
|
| 58 |
+
"""
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 61 |
+
self.variance_epsilon = eps
|
| 62 |
+
|
| 63 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
input_dtype = hidden_states.dtype
|
| 65 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 66 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 67 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 68 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 69 |
+
|
| 70 |
+
def extra_repr(self):
|
| 71 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class YoutuRotaryEmbedding(nn.Module):
|
| 75 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 76 |
+
|
| 77 |
+
def __init__(self, config: YoutuConfig, device=None):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 80 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 81 |
+
|
| 82 |
+
self.config = config
|
| 83 |
+
|
| 84 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 85 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 86 |
+
if self.rope_type != "default":
|
| 87 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 88 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 89 |
+
|
| 90 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 91 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def compute_default_rope_parameters(
|
| 95 |
+
config: YoutuConfig | None = None,
|
| 96 |
+
device: Optional["torch.device"] = None,
|
| 97 |
+
seq_len: int | None = None,
|
| 98 |
+
) -> tuple["torch.Tensor", float]:
|
| 99 |
+
"""
|
| 100 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 101 |
+
Args:
|
| 102 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 103 |
+
The model configuration.
|
| 104 |
+
device (`torch.device`):
|
| 105 |
+
The device to use for initialization of the inverse frequencies.
|
| 106 |
+
seq_len (`int`, *optional*):
|
| 107 |
+
The current sequence length. Unused for this type of RoPE.
|
| 108 |
+
Returns:
|
| 109 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 110 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 111 |
+
"""
|
| 112 |
+
base = config.rope_parameters["rope_theta"]
|
| 113 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 114 |
+
|
| 115 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 116 |
+
|
| 117 |
+
# Compute the inverse frequencies
|
| 118 |
+
inv_freq = 1.0 / (
|
| 119 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 120 |
+
)
|
| 121 |
+
return inv_freq, attention_factor
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 125 |
+
def forward(self, x, position_ids):
|
| 126 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 127 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 128 |
+
|
| 129 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 130 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 131 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 132 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 133 |
+
cos = emb.cos() * self.attention_scaling
|
| 134 |
+
sin = emb.sin() * self.attention_scaling
|
| 135 |
+
|
| 136 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class YoutuMLP(nn.Module):
|
| 140 |
+
def __init__(self, config):
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.config = config
|
| 143 |
+
self.hidden_size = config.hidden_size
|
| 144 |
+
self.intermediate_size = config.intermediate_size
|
| 145 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 146 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 147 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 148 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 152 |
+
return down_proj
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def rotate_half(x):
|
| 156 |
+
"""Rotates half the hidden dims of the input."""
|
| 157 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 158 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 159 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@use_kernel_func_from_hub("rotary_pos_emb")
|
| 163 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 164 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
q (`torch.Tensor`): The query tensor.
|
| 168 |
+
k (`torch.Tensor`): The key tensor.
|
| 169 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 170 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 171 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 172 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 173 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 174 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 175 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 176 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 177 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 178 |
+
Returns:
|
| 179 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 180 |
+
"""
|
| 181 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 182 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 183 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 184 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 185 |
+
return q_embed, k_embed
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 189 |
+
"""
|
| 190 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 191 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 192 |
+
"""
|
| 193 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 194 |
+
if n_rep == 1:
|
| 195 |
+
return hidden_states
|
| 196 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 197 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def eager_attention_forward(
|
| 201 |
+
module: nn.Module,
|
| 202 |
+
query: torch.Tensor,
|
| 203 |
+
key: torch.Tensor,
|
| 204 |
+
value: torch.Tensor,
|
| 205 |
+
attention_mask: torch.Tensor | None,
|
| 206 |
+
scaling: float,
|
| 207 |
+
dropout: float = 0.0,
|
| 208 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 209 |
+
):
|
| 210 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 211 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 212 |
+
|
| 213 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 214 |
+
if attention_mask is not None:
|
| 215 |
+
attn_weights = attn_weights + attention_mask
|
| 216 |
+
|
| 217 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 218 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 219 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 220 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 221 |
+
|
| 222 |
+
return attn_output, attn_weights
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 226 |
+
r"""
|
| 227 |
+
TODO let's just use the original freqcis computation to not have the view
|
| 228 |
+
transpose + reshape! This is not optimized!
|
| 229 |
+
Applies Rotary Position Embedding to the query and key tensors.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
q (`torch.Tensor`): The query tensor.
|
| 233 |
+
k (`torch.Tensor`): The key tensor.
|
| 234 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 235 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 236 |
+
position_ids (`torch.Tensor`):
|
| 237 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 238 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 239 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 240 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 241 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 242 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 243 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 244 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 245 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 246 |
+
Returns:
|
| 247 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 248 |
+
"""
|
| 249 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 250 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 251 |
+
|
| 252 |
+
b, h, s, d = q.shape
|
| 253 |
+
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 254 |
+
|
| 255 |
+
b, h, s, d = k.shape
|
| 256 |
+
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 257 |
+
|
| 258 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 259 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 260 |
+
return q_embed, k_embed
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def yarn_get_mscale(scale=1, mscale=1):
|
| 264 |
+
if scale <= 1:
|
| 265 |
+
return 1.0
|
| 266 |
+
return 0.1 * mscale * math.log(scale) + 1.0
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class YoutuAttention(nn.Module):
|
| 270 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 271 |
+
|
| 272 |
+
def __init__(self, config: YoutuConfig, layer_idx: int):
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.config = config
|
| 275 |
+
self.layer_idx = layer_idx
|
| 276 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 277 |
+
self.attention_dropout = config.attention_dropout
|
| 278 |
+
self.num_heads = config.num_attention_heads
|
| 279 |
+
|
| 280 |
+
self.q_lora_rank = config.q_lora_rank
|
| 281 |
+
self.qk_rope_head_dim = config.qk_rope_head_dim
|
| 282 |
+
self.kv_lora_rank = config.kv_lora_rank
|
| 283 |
+
self.v_head_dim = config.v_head_dim
|
| 284 |
+
self.qk_nope_head_dim = config.qk_nope_head_dim
|
| 285 |
+
self.qk_head_dim = config.qk_head_dim
|
| 286 |
+
|
| 287 |
+
self.is_causal = True
|
| 288 |
+
if self.q_lora_rank is None:
|
| 289 |
+
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
|
| 290 |
+
else:
|
| 291 |
+
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
| 292 |
+
self.q_a_layernorm = YoutuRMSNorm(config.q_lora_rank)
|
| 293 |
+
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
|
| 294 |
+
|
| 295 |
+
self.kv_a_proj_with_mqa = nn.Linear(
|
| 296 |
+
config.hidden_size,
|
| 297 |
+
self.kv_lora_rank + self.qk_rope_head_dim,
|
| 298 |
+
bias=config.attention_bias,
|
| 299 |
+
)
|
| 300 |
+
self.kv_a_layernorm = YoutuRMSNorm(self.kv_lora_rank)
|
| 301 |
+
self.kv_b_proj = nn.Linear(
|
| 302 |
+
self.kv_lora_rank,
|
| 303 |
+
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
| 304 |
+
bias=False,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
self.o_proj = nn.Linear(
|
| 308 |
+
self.num_heads * self.v_head_dim,
|
| 309 |
+
config.hidden_size,
|
| 310 |
+
bias=config.attention_bias,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
self.scaling = self.qk_head_dim ** (-0.5)
|
| 314 |
+
if self.config.rope_parameters.get("rope_type", "default") != "default":
|
| 315 |
+
mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
|
| 316 |
+
scaling_factor = self.config.rope_parameters["factor"]
|
| 317 |
+
if mscale_all_dim:
|
| 318 |
+
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
| 319 |
+
self.scaling = self.scaling * mscale * mscale
|
| 320 |
+
|
| 321 |
+
def forward(
|
| 322 |
+
self,
|
| 323 |
+
hidden_states: torch.Tensor,
|
| 324 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 325 |
+
attention_mask: torch.Tensor | None,
|
| 326 |
+
past_key_values: Cache | None = None,
|
| 327 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 328 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
| 329 |
+
batch_size, seq_length = hidden_states.shape[:-1]
|
| 330 |
+
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
|
| 331 |
+
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
|
| 332 |
+
|
| 333 |
+
if self.q_lora_rank is None:
|
| 334 |
+
q_states = self.q_proj(hidden_states)
|
| 335 |
+
else:
|
| 336 |
+
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
| 337 |
+
q_states = q_states.view(query_shape).transpose(1, 2)
|
| 338 |
+
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 339 |
+
|
| 340 |
+
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
| 341 |
+
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 342 |
+
|
| 343 |
+
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
|
| 344 |
+
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 345 |
+
|
| 346 |
+
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
|
| 347 |
+
|
| 348 |
+
cos, sin = position_embeddings
|
| 349 |
+
if self.config.rope_interleave: # support using interleaved weights for efficiency
|
| 350 |
+
q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
|
| 351 |
+
else:
|
| 352 |
+
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
|
| 353 |
+
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
|
| 354 |
+
|
| 355 |
+
query_states = torch.cat((q_pass, q_rot), dim=-1)
|
| 356 |
+
key_states = torch.cat((k_pass, k_rot), dim=-1)
|
| 357 |
+
|
| 358 |
+
if past_key_values is not None:
|
| 359 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
| 360 |
+
|
| 361 |
+
if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
|
| 362 |
+
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
|
| 363 |
+
|
| 364 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 365 |
+
self.config._attn_implementation, eager_attention_forward
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
attn_output, attn_weights = attention_interface(
|
| 369 |
+
self,
|
| 370 |
+
query_states,
|
| 371 |
+
key_states,
|
| 372 |
+
value_states,
|
| 373 |
+
attention_mask,
|
| 374 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 375 |
+
scaling=self.scaling,
|
| 376 |
+
**kwargs,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
|
| 380 |
+
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
| 381 |
+
|
| 382 |
+
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
|
| 383 |
+
attn_output = self.o_proj(attn_output)
|
| 384 |
+
return attn_output, attn_weights
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class YoutuDecoderLayer(GradientCheckpointingLayer):
|
| 388 |
+
def __init__(self, config: YoutuConfig, layer_idx: int):
|
| 389 |
+
super().__init__()
|
| 390 |
+
self.hidden_size = config.hidden_size
|
| 391 |
+
|
| 392 |
+
self.self_attn = YoutuAttention(config=config, layer_idx=layer_idx)
|
| 393 |
+
|
| 394 |
+
self.mlp = YoutuMLP(config)
|
| 395 |
+
self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 396 |
+
self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 397 |
+
|
| 398 |
+
def forward(
|
| 399 |
+
self,
|
| 400 |
+
hidden_states: torch.Tensor,
|
| 401 |
+
attention_mask: torch.Tensor | None = None,
|
| 402 |
+
position_ids: torch.LongTensor | None = None,
|
| 403 |
+
past_key_values: Cache | None = None,
|
| 404 |
+
use_cache: bool | None = False,
|
| 405 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 406 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 407 |
+
) -> torch.Tensor:
|
| 408 |
+
residual = hidden_states
|
| 409 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 410 |
+
# Self Attention
|
| 411 |
+
hidden_states, _ = self.self_attn(
|
| 412 |
+
hidden_states=hidden_states,
|
| 413 |
+
attention_mask=attention_mask,
|
| 414 |
+
position_ids=position_ids,
|
| 415 |
+
past_key_values=past_key_values,
|
| 416 |
+
use_cache=use_cache,
|
| 417 |
+
position_embeddings=position_embeddings,
|
| 418 |
+
**kwargs,
|
| 419 |
+
)
|
| 420 |
+
hidden_states = residual + hidden_states
|
| 421 |
+
|
| 422 |
+
# Fully Connected
|
| 423 |
+
residual = hidden_states
|
| 424 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 425 |
+
hidden_states = self.mlp(hidden_states)
|
| 426 |
+
hidden_states = residual + hidden_states
|
| 427 |
+
return hidden_states
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
@auto_docstring
|
| 431 |
+
class YoutuPreTrainedModel(PreTrainedModel):
|
| 432 |
+
config: YoutuConfig
|
| 433 |
+
base_model_prefix = "model"
|
| 434 |
+
supports_gradient_checkpointing = True
|
| 435 |
+
_no_split_modules = ["YoutuDecoderLayer"]
|
| 436 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 437 |
+
_supports_flash_attn = True
|
| 438 |
+
_supports_sdpa = True
|
| 439 |
+
_supports_flex_attn = True
|
| 440 |
+
|
| 441 |
+
_can_compile_fullgraph = True
|
| 442 |
+
_supports_attention_backend = True
|
| 443 |
+
_can_record_outputs = {
|
| 444 |
+
"hidden_states": YoutuDecoderLayer,
|
| 445 |
+
"attentions": YoutuAttention,
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
@torch.no_grad()
|
| 449 |
+
def _init_weights(self, module):
|
| 450 |
+
super()._init_weights(module)
|
| 451 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 452 |
+
embed_std = getattr(self.config, "embedding_initializer_range", 2 * std)
|
| 453 |
+
if isinstance(module, nn.Embedding):
|
| 454 |
+
init.normal_(module.weight, mean=0.0, std=embed_std)
|
| 455 |
+
if module.padding_idx is not None:
|
| 456 |
+
init.zeros_(module.weight.data[module.padding_idx])
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
@auto_docstring
|
| 460 |
+
class YoutuModel(YoutuPreTrainedModel):
|
| 461 |
+
def __init__(self, config: YoutuConfig):
|
| 462 |
+
super().__init__(config)
|
| 463 |
+
self.padding_idx = config.pad_token_id
|
| 464 |
+
self.vocab_size = config.vocab_size
|
| 465 |
+
|
| 466 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 467 |
+
self.layers = nn.ModuleList(
|
| 468 |
+
[YoutuDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 469 |
+
)
|
| 470 |
+
self.norm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 471 |
+
self.rotary_emb = YoutuRotaryEmbedding(config=config)
|
| 472 |
+
self.gradient_checkpointing = False
|
| 473 |
+
|
| 474 |
+
# Initialize weights and apply final processing
|
| 475 |
+
self.post_init()
|
| 476 |
+
|
| 477 |
+
@merge_with_config_defaults
|
| 478 |
+
@capture_outputs
|
| 479 |
+
@auto_docstring
|
| 480 |
+
def forward(
|
| 481 |
+
self,
|
| 482 |
+
input_ids: torch.LongTensor | None = None,
|
| 483 |
+
attention_mask: torch.Tensor | None = None,
|
| 484 |
+
position_ids: torch.LongTensor | None = None,
|
| 485 |
+
past_key_values: Cache | None = None,
|
| 486 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 487 |
+
use_cache: bool | None = None,
|
| 488 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 489 |
+
) -> BaseModelOutputWithPast:
|
| 490 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 491 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 492 |
+
|
| 493 |
+
if inputs_embeds is None:
|
| 494 |
+
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
| 495 |
+
|
| 496 |
+
if use_cache and past_key_values is None:
|
| 497 |
+
past_key_values = DynamicCache(config=self.config)
|
| 498 |
+
|
| 499 |
+
if position_ids is None:
|
| 500 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 501 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 502 |
+
position_ids = position_ids.unsqueeze(0)
|
| 503 |
+
|
| 504 |
+
causal_mask = create_causal_mask(
|
| 505 |
+
config=self.config,
|
| 506 |
+
inputs_embeds=inputs_embeds,
|
| 507 |
+
attention_mask=attention_mask,
|
| 508 |
+
past_key_values=past_key_values,
|
| 509 |
+
position_ids=position_ids,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
hidden_states = inputs_embeds
|
| 513 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
|
| 514 |
+
|
| 515 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 516 |
+
hidden_states = decoder_layer(
|
| 517 |
+
hidden_states,
|
| 518 |
+
attention_mask=causal_mask,
|
| 519 |
+
position_embeddings=position_embeddings,
|
| 520 |
+
position_ids=position_ids,
|
| 521 |
+
past_key_values=past_key_values,
|
| 522 |
+
use_cache=use_cache,
|
| 523 |
+
**kwargs,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
hidden_states = self.norm(hidden_states)
|
| 527 |
+
return BaseModelOutputWithPast(
|
| 528 |
+
last_hidden_state=hidden_states,
|
| 529 |
+
past_key_values=past_key_values,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@auto_docstring
|
| 534 |
+
class YoutuForCausalLM(YoutuPreTrainedModel, GenerationMixin):
|
| 535 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 536 |
+
_tp_plan = {"lm_head": "colwise_gather_output"}
|
| 537 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 538 |
+
|
| 539 |
+
def __init__(self, config):
|
| 540 |
+
super().__init__(config)
|
| 541 |
+
self.model = YoutuModel(config)
|
| 542 |
+
self.vocab_size = config.vocab_size
|
| 543 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 544 |
+
|
| 545 |
+
# Initialize weights and apply final processing
|
| 546 |
+
self.post_init()
|
| 547 |
+
|
| 548 |
+
@can_return_tuple
|
| 549 |
+
@auto_docstring
|
| 550 |
+
def forward(
|
| 551 |
+
self,
|
| 552 |
+
input_ids: torch.LongTensor | None = None,
|
| 553 |
+
attention_mask: torch.Tensor | None = None,
|
| 554 |
+
position_ids: torch.LongTensor | None = None,
|
| 555 |
+
past_key_values: Cache | None = None,
|
| 556 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 557 |
+
labels: torch.LongTensor | None = None,
|
| 558 |
+
use_cache: bool | None = None,
|
| 559 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 560 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 561 |
+
) -> CausalLMOutputWithPast:
|
| 562 |
+
r"""
|
| 563 |
+
Example:
|
| 564 |
+
|
| 565 |
+
```python
|
| 566 |
+
>>> from transformers import AutoTokenizer, YoutuForCausalLM
|
| 567 |
+
|
| 568 |
+
>>> model = YoutuForCausalLM.from_pretrained("meta-youtu/Youtu-2-7b-hf")
|
| 569 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-youtu/Youtu-2-7b-hf")
|
| 570 |
+
|
| 571 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 572 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 573 |
+
|
| 574 |
+
>>> # Generate
|
| 575 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 576 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 577 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 578 |
+
```"""
|
| 579 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 580 |
+
input_ids=input_ids,
|
| 581 |
+
attention_mask=attention_mask,
|
| 582 |
+
position_ids=position_ids,
|
| 583 |
+
past_key_values=past_key_values,
|
| 584 |
+
inputs_embeds=inputs_embeds,
|
| 585 |
+
use_cache=use_cache,
|
| 586 |
+
**kwargs,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
hidden_states = outputs.last_hidden_state
|
| 590 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 591 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 592 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 593 |
+
|
| 594 |
+
loss = None
|
| 595 |
+
if labels is not None:
|
| 596 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 597 |
+
|
| 598 |
+
return CausalLMOutputWithPast(
|
| 599 |
+
loss=loss,
|
| 600 |
+
logits=logits,
|
| 601 |
+
past_key_values=outputs.past_key_values,
|
| 602 |
+
hidden_states=outputs.hidden_states,
|
| 603 |
+
attentions=outputs.attentions,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
__all__ = ["YoutuPreTrainedModel", "YoutuModel", "YoutuForCausalLM"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modular_youtu.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 4 |
+
# and OPT implementations in this library. It has been modified from its
|
| 5 |
+
# original forms to accommodate minor architectural differences compared
|
| 6 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from huggingface_hub.dataclasses import strict
|
| 23 |
+
from torch import nn
|
| 24 |
+
|
| 25 |
+
from ... import initialization as init
|
| 26 |
+
from ...modeling_utils import PreTrainedModel
|
| 27 |
+
from ...utils import auto_docstring, logging
|
| 28 |
+
from ..deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
|
| 29 |
+
from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3Attention
|
| 30 |
+
from ..llama.modeling_llama import (
|
| 31 |
+
LlamaDecoderLayer,
|
| 32 |
+
LlamaForCausalLM,
|
| 33 |
+
LlamaModel,
|
| 34 |
+
LlamaPreTrainedModel,
|
| 35 |
+
LlamaRMSNorm,
|
| 36 |
+
LlamaRotaryEmbedding,
|
| 37 |
+
)
|
| 38 |
+
from ..qwen3.modeling_qwen3 import Qwen3MLP
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@auto_docstring(checkpoint="tencent/Youtu-LLM-2B")
|
| 45 |
+
@strict
|
| 46 |
+
class YoutuConfig(DeepseekV3Config):
|
| 47 |
+
r"""
|
| 48 |
+
rope_interleave (`bool`, *optional*, defaults to `True`):
|
| 49 |
+
Whether to interleave the rotary position embeddings.
|
| 50 |
+
embedding_initializer_range (`float`, *optional*):
|
| 51 |
+
The standard deviation of the truncated_normal_initializer for initializing all embedding matrices.
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
>>> from transformers import YoutuModel, YoutuConfig
|
| 55 |
+
>>> # Initializing a Youtu-LLM-2B style configuration
|
| 56 |
+
>>> configuration = YoutuConfig()
|
| 57 |
+
>>> # Accessing the model configuration
|
| 58 |
+
>>> configuration = model.config
|
| 59 |
+
```"""
|
| 60 |
+
|
| 61 |
+
model_type = "youtu"
|
| 62 |
+
base_model_tp_plan = {
|
| 63 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 64 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 65 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 66 |
+
}
|
| 67 |
+
attribute_map = {}
|
| 68 |
+
|
| 69 |
+
vocab_size: int = 128256
|
| 70 |
+
hidden_size: int = 2048
|
| 71 |
+
intermediate_size: int = 6144
|
| 72 |
+
num_hidden_layers: int = 32
|
| 73 |
+
num_attention_heads: int = 16
|
| 74 |
+
num_key_value_heads: int = 16
|
| 75 |
+
max_position_embeddings: int = 131072
|
| 76 |
+
initializer_range: float | None = None
|
| 77 |
+
embedding_initializer_range: float | None = None
|
| 78 |
+
pad_token_id: int | None = None
|
| 79 |
+
bos_token_id: int | None = 128000
|
| 80 |
+
eos_token_id: int | list[int] | None = 128001
|
| 81 |
+
tie_word_embeddings: bool = True
|
| 82 |
+
|
| 83 |
+
# remove unused attribute
|
| 84 |
+
n_shared_experts = AttributeError()
|
| 85 |
+
n_routed_experts = AttributeError()
|
| 86 |
+
routed_scaling_factor = AttributeError()
|
| 87 |
+
n_group = AttributeError()
|
| 88 |
+
topk_group = AttributeError()
|
| 89 |
+
num_experts_per_tok = AttributeError()
|
| 90 |
+
first_k_dense_replace = AttributeError()
|
| 91 |
+
norm_topk_prob = AttributeError()
|
| 92 |
+
pretraining_tp = AttributeError()
|
| 93 |
+
moe_intermediate_size = AttributeError()
|
| 94 |
+
|
| 95 |
+
def __post_init__(self, **kwargs):
|
| 96 |
+
if self.initializer_range is None:
|
| 97 |
+
if self.hidden_size != 0:
|
| 98 |
+
self.initializer_range = 2.0 / (5.0 * self.hidden_size) ** 0.5
|
| 99 |
+
else:
|
| 100 |
+
self.initializer_range = 0.02
|
| 101 |
+
|
| 102 |
+
self.embedding_initializer_range = self.embedding_initializer_range or 2.0 * self.initializer_range
|
| 103 |
+
super().__post_init__(**kwargs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class YoutuRMSNorm(LlamaRMSNorm):
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class YoutuRotaryEmbedding(LlamaRotaryEmbedding):
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class YoutuMLP(Qwen3MLP):
|
| 115 |
+
pass
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class YoutuAttention(DeepseekV3Attention):
|
| 119 |
+
pass
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class YoutuDecoderLayer(LlamaDecoderLayer):
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class YoutuPreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
|
| 127 |
+
@torch.no_grad()
|
| 128 |
+
def _init_weights(self, module):
|
| 129 |
+
PreTrainedModel._init_weights(self, module)
|
| 130 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 131 |
+
embed_std = getattr(self.config, "embedding_initializer_range", 2 * std)
|
| 132 |
+
if isinstance(module, nn.Embedding):
|
| 133 |
+
init.normal_(module.weight, mean=0.0, std=embed_std)
|
| 134 |
+
if module.padding_idx is not None:
|
| 135 |
+
init.zeros_(module.weight.data[module.padding_idx])
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class YoutuModel(LlamaModel):
|
| 139 |
+
pass
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class YoutuForCausalLM(LlamaForCausalLM):
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
__all__ = [
|
| 147 |
+
"YoutuConfig",
|
| 148 |
+
"YoutuPreTrainedModel",
|
| 149 |
+
"YoutuModel",
|
| 150 |
+
"YoutuForCausalLM",
|
| 151 |
+
]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/testing_utils.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/training_args.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_070000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f9d782f9b8c989c4bdaccc104827d9b426cdf6c42c2bd671b6e40541ceb24c2
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_078000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68bdb7c05a6e6a3e90081c66c23125a3951666f75f5171b90f4ca8e23da89f69
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_079000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c31ea7165f98543157bb82e8ff15f8183ee06eaa5cfa6901c6ed884fbe5e0ec
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_343000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a4d7035ba11f7286021fe81bbe41d62ea19e278f5466a9c6506e495e025e2e6
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_352000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8cf82b122d41a279824fef01b76fe887328610fe616cccbda5186d6bed45dd29
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_390000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:84de046c4742e1f4294ba4ee53edbadd466c91ff6003748fbc7e2bcfc56e1d32
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_433000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71cd93a38ffcecccd3b1e054151774919038ec40a6aec162ebcfa5db27fa7b46
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_471000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b1caa92ac5cea91c492b668a52dbce1d34b2cf72db6edbd2ee768d2608cf203
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_565000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ebd372ab29ab0dde7b3aab27de2e20487c5fce1628a03df9d46b59795702246
|
| 3 |
+
size 897562466
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_571000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:477f5f24e284fcf6630d7a64ed678211de0b09b4ab11f149bc02dd43c78678da
|
| 3 |
+
size 897562466
|