JinghuiLuAstronaut committed on
Commit
8b31547
·
verified ·
1 Parent(s): 76bde08

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507.pid +1 -0
  2. LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta.pid +1 -0
  3. LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta_rerun.log +61 -0
  4. LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/sample_fully_coupled.py +146 -0
  5. LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/train_char.py +618 -0
  6. LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513.log +94 -0
  7. LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513_trace.log +1 -0
  8. LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0010000_t1p45.log +68 -0
  9. LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0020000_t1p45.log +68 -0
  10. LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0030000_t1p45.log +68 -0
  11. LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/processed_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_steps128_c1024_t1p45_n1024.txt +4 -0
  12. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.bat +71 -0
  13. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.fish +124 -0
  14. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/f2py +10 -0
  15. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/pydoc.bat +22 -0
  16. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/audio_utils.py +1254 -0
  17. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/__init__.py +13 -0
  18. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/add_new_model_like.py +790 -0
  19. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/chat.py +673 -0
  20. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/download.py +40 -0
  21. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/serve.py +241 -0
  22. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/system.py +139 -0
  23. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/transformers.py +41 -0
  24. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/__init__.py +33 -0
  25. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/configuration_utils.py +110 -0
  26. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/hyperparameter_search.py +123 -0
  27. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_transforms.py +1073 -0
  28. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/__init__.py +30 -0
  29. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/configuration_gemma3.py +225 -0
  30. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_gemma3.py +250 -0
  31. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_pil_gemma3.py +225 -0
  32. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py +1118 -0
  33. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modular_gemma3.py +941 -0
  34. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/processing_gemma3.py +165 -0
  35. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/__init__.py +27 -0
  36. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/configuration_youtu.py +107 -0
  37. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modeling_youtu.py +607 -0
  38. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modular_youtu.py +151 -0
  39. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/testing_utils.py +0 -0
  40. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/training_args.py +0 -0
  41. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_070000.pt +3 -0
  42. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_078000.pt +3 -0
  43. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_079000.pt +3 -0
  44. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_343000.pt +3 -0
  45. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_352000.pt +3 -0
  46. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_390000.pt +3 -0
  47. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_433000.pt +3 -0
  48. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_471000.pt +3 -0
  49. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_565000.pt +3 -0
  50. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_571000.pt +3 -0
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507.pid ADDED
@@ -0,0 +1 @@
 
 
1
+ 354158
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta.pid ADDED
@@ -0,0 +1 @@
 
 
1
+ 354493
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/runs/char_ar_lta_4gpu_5k_20260507_lta_rerun.log ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ *****************************************
3
+ Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ *****************************************
5
+ [setup] device=cuda:0 rank=0 world_size=4
6
+ [rank2]:[W507 18:31:20.830802063 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
7
+ [rank0]:[W507 18:31:20.831852183 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
8
+ [rank3]:[W507 18:31:20.833351576 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
9
+ [rank1]:[W507 18:31:20.834403717 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
10
+ [data] chars=1115394 vocab=65 train=1003854 val=111540
11
+ [lta] step=1 loss=4.2253 elapsed=0.6s
12
+ [lta] step=100 loss=1.8800 elapsed=1.6s
13
+ [lta] step=200 loss=1.7154 elapsed=2.6s
14
+ [lta] step=300 loss=1.8433 elapsed=3.7s
15
+ [lta] step=400 loss=1.5804 elapsed=4.7s
16
+ [lta] step=500 loss=1.3491 elapsed=5.7s
17
+ [lta] step=600 loss=1.5344 elapsed=7.2s
18
+ [lta] step=700 loss=1.2540 elapsed=8.3s
19
+ [lta] step=800 loss=1.3757 elapsed=9.3s
20
+ [lta] step=900 loss=1.4476 elapsed=10.3s
21
+ [lta] step=1000 loss=1.4170 elapsed=11.3s
22
+ [lta] step=1100 loss=1.5610 elapsed=12.8s
23
+ [lta] step=1200 loss=1.4933 elapsed=13.8s
24
+ [lta] step=1300 loss=1.5656 elapsed=14.9s
25
+ [lta] step=1400 loss=1.5198 elapsed=15.9s
26
+ [lta] step=1500 loss=1.4798 elapsed=17.0s
27
+ [lta] step=1600 loss=1.5783 elapsed=18.5s
28
+ [lta] step=1700 loss=1.1984 elapsed=19.5s
29
+ [lta] step=1800 loss=1.2941 elapsed=20.5s
30
+ [lta] step=1900 loss=1.5220 elapsed=21.5s
31
+ [lta] step=2000 loss=1.2615 elapsed=22.6s
32
+ [lta] step=2100 loss=1.3370 elapsed=24.1s
33
+ [lta] step=2200 loss=1.1854 elapsed=25.1s
34
+ [lta] step=2300 loss=0.9726 elapsed=26.1s
35
+ [lta] step=2400 loss=1.4613 elapsed=27.1s
36
+ [lta] step=2500 loss=1.3016 elapsed=28.2s
37
+ [lta] step=2600 loss=1.3408 elapsed=29.7s
38
+ [lta] step=2700 loss=1.3022 elapsed=30.7s
39
+ [lta] step=2800 loss=1.4492 elapsed=31.7s
40
+ [lta] step=2900 loss=1.1530 elapsed=32.7s
41
+ [lta] step=3000 loss=1.4642 elapsed=33.8s
42
+ [lta] step=3100 loss=1.2645 elapsed=35.3s
43
+ [lta] step=3200 loss=1.4777 elapsed=36.3s
44
+ [lta] step=3300 loss=1.0923 elapsed=37.4s
45
+ [lta] step=3400 loss=1.1992 elapsed=38.5s
46
+ [lta] step=3500 loss=1.4760 elapsed=39.6s
47
+ [lta] step=3600 loss=1.5702 elapsed=41.0s
48
+ [lta] step=3700 loss=1.5327 elapsed=42.1s
49
+ [lta] step=3800 loss=1.5319 elapsed=43.1s
50
+ [lta] step=3900 loss=1.3098 elapsed=44.3s
51
+ [lta] step=4000 loss=1.6050 elapsed=45.3s
52
+ [lta] step=4100 loss=1.2478 elapsed=46.8s
53
+ [lta] step=4200 loss=1.3497 elapsed=47.8s
54
+ [lta] step=4300 loss=1.3263 elapsed=48.8s
55
+ [lta] step=4400 loss=1.3406 elapsed=49.9s
56
+ [lta] step=4500 loss=1.3295 elapsed=50.9s
57
+ [lta] step=4600 loss=1.5340 elapsed=52.4s
58
+ [lta] step=4700 loss=1.4847 elapsed=53.4s
59
+ [lta] step=4800 loss=1.1464 elapsed=54.4s
60
+ [lta] step=4900 loss=1.4102 elapsed=55.5s
61
+ [lta] step=5000 loss=1.4638 elapsed=56.5s
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/sample_fully_coupled.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import csv
6
+ import json
7
+ import math
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ SCRIPT_DIR = Path(__file__).resolve().parent
15
+ if str(SCRIPT_DIR) not in sys.path:
16
+ sys.path.insert(0, str(SCRIPT_DIR))
17
+
18
+ from train_char import CharTokenizer, ModelConfig, TinyTransformer, standard_gamma, text_stats
19
+
20
+
21
def pick_device(name: str) -> torch.device:
    """Resolve a device name to a ``torch.device``.

    Any name other than ``"auto"`` is handed to ``torch.device`` unchanged.
    ``"auto"`` prefers CUDA, then MPS (when this torch build exposes that
    backend and it is available), and finally falls back to the CPU.
    """
    if name == "auto":
        if torch.cuda.is_available():
            name = "cuda"
        else:
            # Older torch builds have no ``torch.backends.mps`` attribute.
            mps_backend = getattr(torch.backends, "mps", None)
            name = "mps" if mps_backend and mps_backend.is_available() else "cpu"
    return torch.device(name)
29
+
30
+
31
@torch.no_grad()
def decode(
    model: TinyTransformer,
    tokenizer: CharTokenizer,
    *,
    length: int,
    steps: int,
    c_min: float,
    c_max: float,
    temp: float,
    final_from: str,
    seed: int,
    device: torch.device,
) -> str:
    """Sample one text of ``length`` characters by iterative Dirichlet resampling.

    The state starts as a normalised Gamma draw with uniform parameters
    ``1 / vocab_size``. At each of ``steps`` iterations the model predicts an
    endpoint distribution from the current state, a mean is formed from the
    uniform distribution, the current state and that endpoint, and a new state
    is drawn from Gamma samples with parameters ``mean * conc``, where ``conc``
    is interpolated log-linearly from ``c_min`` to ``c_max``.

    Args:
        model: network called as ``model(probs, t)`` and returning logits.
        tokenizer: supplies ``vocab_size`` and ``decode``.
        length: number of positions to generate.
        steps: number of resampling iterations.
        c_min: concentration at t = 0; must be positive for ``math.log``.
        c_max: concentration at t = 1.
        temp: the logits are divided by this before the softmax.
        final_from: ``"state"`` (last sampled state), ``"endpoint"`` (last
            model prediction) or ``"blend"`` (their equal-weight average).
        seed: passed to ``torch.manual_seed``, which reseeds the global RNG.
        device: device on which the state tensors are created.

    Returns:
        The decoded argmax characters of the chosen final distribution.

    Raises:
        ValueError: if ``final_from`` is not one of the three supported modes.
    """
    # Fail before sampling rather than after up to `steps` forward passes.
    if final_from not in ("state", "endpoint", "blend"):
        raise ValueError(final_from)

    torch.manual_seed(seed)
    eps = 1e-8
    vocab_size = tokenizer.vocab_size
    alpha = torch.full((1, length, vocab_size), 1.0 / vocab_size, device=device).clamp_min(eps)
    probs = standard_gamma(alpha).clamp_min(eps)
    probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
    last_endpoint = probs

    # Loop-invariant pieces of the log-linear concentration schedule.
    log_c_min = math.log(c_min)
    log_c_ratio = math.log(c_max / c_min)

    # The no_grad decorator matters here: each new state is computed from the
    # previous one, so with autograd enabled the graph of every step would be
    # kept alive until the loop finishes.
    for step in range(steps):
        t_value = (step + 1) / max(steps, 1)
        t = torch.full((1,), t_value, device=device)
        logits = model(probs, t) / temp
        endpoint = F.softmax(logits, dim=-1)
        last_endpoint = endpoint
        # "Fully coupled": the support and semantic times are the same value.
        support_t = t_value
        semantic_t = t_value
        forward_endpoint = (1.0 - semantic_t) * probs + semantic_t * endpoint
        mean = (1.0 - support_t) / float(vocab_size) + support_t * forward_endpoint
        mean = mean.clamp_min(eps)
        mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)
        conc = math.exp(log_c_min + support_t * log_c_ratio)
        sample = standard_gamma((mean * conc).clamp_min(eps)).clamp_min(eps)
        probs = sample / sample.sum(dim=-1, keepdim=True).clamp_min(eps)

    if final_from == "state":
        final = probs
    elif final_from == "endpoint":
        final = last_endpoint
    else:  # "blend"
        final = 0.5 * probs + 0.5 * last_endpoint
    ids = final.argmax(dim=-1)[0]
    return tokenizer.decode(ids)
76
+
77
+
78
def main() -> None:
    """Sweep decoding settings for one checkpoint and write the samples to disk.

    Loads the checkpoint, rebuilds the tokenizer and model from it, then calls
    ``decode`` for every combination of step count, ``c_max``, temperature and
    ``final_from`` mode. Each sample is written to ``<out_dir>/<name>.txt`` and
    one row per configuration (settings plus ``text_stats``) is written to
    ``<out_dir>/summary.tsv``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--out_dir", required=True)
    p.add_argument("--length", type=int, default=128)
    p.add_argument("--seed", type=int, default=20260507)
    p.add_argument("--device", default="auto")
    args = p.parse_args()

    device = pick_device(args.device)
    print(f"[setup] device={device}", flush=True)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point --checkpoint at files you trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    cfg = ModelConfig(**ckpt["model_config"])
    tok_data = ckpt["extra"]["tokenizer"]
    # Build a tokenizer from the saved characters, then overwrite its tables
    # with the stored ones so the ids match the checkpoint exactly.
    tokenizer = CharTokenizer("".join(tok_data["itos"]))
    tokenizer.itos = tok_data["itos"]
    tokenizer.stoi = tok_data["stoi"]
    tokenizer.vocab_size = tok_data["vocab_size"]
    model = TinyTransformer(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    configs = []
    for steps in [128, 256, 512, 1024]:
        for c_max in [64.0, 16.0, 4.0, 1.0]:
            for temp in [0.8, 1.0, 1.3, 1.8]:
                for final_from in ["state", "endpoint", "blend"]:
                    configs.append((steps, c_max, temp, final_from))

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    rows = []
    for i, (steps, c_max, temp, final_from) in enumerate(configs):
        name = f"steps{steps}_c{c_max:g}_temp{str(temp).replace('.', 'p')}_{final_from}"
        text = decode(
            model,
            tokenizer,
            length=args.length,
            steps=steps,
            c_min=1.0,
            c_max=c_max,
            temp=temp,
            final_from=final_from,
            seed=args.seed,
            device=device,
        )
        stats = text_stats(text)
        row = {
            "name": name,
            "steps": steps,
            "c_max": c_max,
            "temp": temp,
            "final_from": final_from,
            **stats,
        }
        rows.append(row)
        (out_dir / f"{name}.txt").write_text(text, encoding="utf-8")
        # Log every 12th configuration to keep the console output short.
        if i % 12 == 0:
            print("[sample]", row, flush=True)

    keys = list(rows[0].keys())
    # newline="" lets the csv module control line endings itself; without it
    # rows can gain an extra "\r" on platforms that translate "\n".
    with (out_dir / "summary.tsv").open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=keys, delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)
143
+
144
+
145
+ if __name__ == "__main__":
146
+ main()
LTA_openwebtext_dualt/experiments/nanogpt_tinyshakespeare_char/train_char.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import math
7
+ import os
8
+ import time
9
+ import urllib.request
10
+ from dataclasses import asdict, dataclass
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
18
+
19
+
20
+ DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
21
+
22
+
23
def setup_distributed(name: str) -> tuple[torch.device, int, int, int, bool]:
    """Choose the device and, for multi-process launches, start the process group.

    ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` are read from the environment
    (defaulting to 1, 0 and 0). When ``WORLD_SIZE`` is greater than 1 an NCCL
    process group is initialised on ``cuda:<LOCAL_RANK>`` and ``name`` is
    ignored. Otherwise ``name`` selects the device, with ``"auto"`` meaning
    CUDA, then MPS, then CPU.

    Returns:
        ``(device, rank, local_rank, world_size, is_ddp)``.

    Raises:
        RuntimeError: if ``WORLD_SIZE`` is greater than 1 but CUDA is unavailable.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))

    if world_size > 1:
        if not torch.cuda.is_available():
            raise RuntimeError("DDP mode expects CUDA. Run single-process for CPU/MPS.")
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
        return torch.device("cuda", local_rank), rank, local_rank, world_size, True

    # Single-process path: an explicit name wins, otherwise probe accelerators.
    if name != "auto":
        device = torch.device(name)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device, rank, local_rank, world_size, False
40
+
41
+
42
def cleanup_distributed(is_ddp: bool) -> None:
    """Destroy the process group if this run used DDP and one is initialised."""
    if not is_ddp:
        return
    if dist.is_initialized():
        dist.destroy_process_group()
45
+
46
+
47
def is_main_process(rank: int) -> bool:
    """Return True only for rank 0."""
    return 0 == rank
49
+
50
+
51
class CharTokenizer:
    """Character-level tokenizer whose vocabulary is the sorted set of characters of a text."""

    def __init__(self, text: str):
        """Build the id tables from the distinct characters in ``text``."""
        vocabulary = sorted(set(text))
        self.itos = vocabulary
        self.stoi = {ch: idx for idx, ch in enumerate(vocabulary)}
        self.vocab_size = len(vocabulary)

    def encode(self, text: str) -> list[int]:
        """Map each character to its id; an unknown character raises ``KeyError``."""
        lookup = self.stoi
        return [lookup[ch] for ch in text]

    def decode(self, ids: list[int] | torch.Tensor) -> str:
        """Turn a list or tensor of ids back into a string."""
        if isinstance(ids, torch.Tensor):
            ids = ids.detach().cpu().tolist()
        table = self.itos
        return "".join(table[int(token)] for token in ids)

    def to_json(self) -> dict:
        """Return the tokenizer tables as a plain dictionary."""
        return dict(itos=self.itos, stoi=self.stoi, vocab_size=self.vocab_size)
68
+
69
+
70
def ensure_tinyshakespeare(data_dir: Path) -> None:
    """Make sure ``data_dir / "input.txt"`` exists, downloading it if needed.

    The download goes to a per-process temporary file in the same directory
    and is moved into place only after it completes, so an interrupted
    download never leaves a truncated ``input.txt`` that a later run would
    accept as valid.

    Args:
        data_dir: directory holding the dataset; created if missing.
    """
    data_dir.mkdir(parents=True, exist_ok=True)
    path = data_dir / "input.txt"
    if path.exists():
        return
    print(f"[data] downloading {DATA_URL}", flush=True)
    # The pid keeps concurrent processes from writing the same temporary file.
    tmp_path = path.with_name(f"{path.name}.{os.getpid()}.part")
    try:
        urllib.request.urlretrieve(DATA_URL, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Already moved away on success; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)
76
+
77
+
78
def load_tinyshakespeare(data_dir: Path) -> tuple[str, CharTokenizer, torch.Tensor, torch.Tensor]:
    """Read ``input.txt`` from ``data_dir`` and split it 90/10 into train and validation ids.

    Returns:
        ``(text, tokenizer, train_ids, val_ids)``: the tokenizer is built from
        the full text and the id tensors have dtype ``torch.long``.
    """
    corpus = (data_dir / "input.txt").read_text(encoding="utf-8")
    tokenizer = CharTokenizer(corpus)
    token_ids = torch.tensor(tokenizer.encode(corpus), dtype=torch.long)
    n_train = int(0.9 * len(token_ids))
    train_ids, val_ids = token_ids[:n_train], token_ids[n_train:]
    return corpus, tokenizer, train_ids, val_ids
85
+
86
+
87
def get_batch(data: torch.Tensor, *, batch_size: int, block_size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Draw random windows from ``data`` together with their next-token targets.

    Returns:
        ``(x, y)``, each with ``batch_size`` rows of ``block_size`` ids moved to
        ``device``; ``y`` is ``x`` shifted one position to the right in ``data``.
    """
    starts = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(data[start : start + block_size])
        targets.append(data[start + 1 : start + block_size + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
92
+
93
+
94
def get_block_batch(data: torch.Tensor, *, batch_size: int, block_size: int, device: torch.device) -> torch.Tensor:
    """Draw ``batch_size`` random windows of ``block_size`` ids from ``data`` and move them to ``device``."""
    starts = torch.randint(0, len(data) - block_size, (batch_size,))
    windows = []
    for start in starts:
        windows.append(data[start : start + block_size])
    return torch.stack(windows).to(device)
98
+
99
+
100
+ def _wrapped_window(data: torch.Tensor, start: int, width: int) -> torch.Tensor:
101
+ end = start + width
102
+ if end <= len(data):
103
+ return data[start:end]
104
+ return torch.cat([data[start:], data[: end % len(data)]], dim=0)
105
+
106
+
107
def get_stream_batch(
    data: torch.Tensor,
    *,
    batch_size: int,
    block_size: int,
    device: torch.device,
    cursor: int,
    rank: int = 0,
    world_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Read a batch of consecutive ``block_size + 1`` wide windows starting at ``cursor``.

    Each rank starts ``rank * batch_size * (block_size + 1)`` elements after
    ``cursor``, and the returned cursor advances by the span covered by all
    ``world_size`` ranks. All offsets are taken modulo ``len(data)``.

    Returns:
        ``(x, y, next_cursor)``: ``x`` holds the first ``block_size`` ids of
        every window and ``y`` the same window shifted by one.
    """
    width = block_size + 1
    total = len(data)
    rank_base = (cursor + rank * batch_size * width) % total
    windows = [
        _wrapped_window(data, (rank_base + row * width) % total, width)
        for row in range(batch_size)
    ]
    samples = torch.stack(windows)
    next_cursor = (cursor + world_size * batch_size * width) % total
    inputs = samples[:, :block_size].to(device)
    targets = samples[:, 1:].to(device)
    return inputs, targets, next_cursor
122
+
123
+
124
def get_stream_block_batch(
    data: torch.Tensor,
    *,
    batch_size: int,
    block_size: int,
    device: torch.device,
    cursor: int,
    rank: int = 0,
    world_size: int = 1,
) -> tuple[torch.Tensor, int]:
    """Read a batch of consecutive ``block_size`` wide windows starting at ``cursor``.

    Each rank starts ``rank * batch_size * block_size`` elements after
    ``cursor``, and the returned cursor advances by the span covered by all
    ``world_size`` ranks. All offsets are taken modulo ``len(data)``.

    Returns:
        ``(samples, next_cursor)`` with ``samples`` moved to ``device``.
    """
    width = block_size
    total = len(data)
    rank_base = (cursor + rank * batch_size * width) % total
    windows = [
        _wrapped_window(data, (rank_base + row * width) % total, width)
        for row in range(batch_size)
    ]
    next_cursor = (cursor + world_size * batch_size * width) % total
    return torch.stack(windows).to(device), next_cursor
139
+
140
+
141
@dataclass
class ModelConfig:
    """Hyper-parameters for ``TinyTransformer``."""

    # Number of output logits and size of the token / probability inputs.
    vocab_size: int
    # Size of the position-embedding table and of the causal mask buffer.
    block_size: int = 128
    # Number of transformer blocks.
    n_layer: int = 4
    # Attention heads per block; must divide n_embd.
    n_head: int = 4
    # Width of the hidden representation.
    n_embd: int = 128
    # Dropout probability used throughout the model.
    dropout: float = 0.1
    # When True, attention is masked with a lower-triangular mask.
    causal: bool = True
    # "tokens" embeds integer ids; "probs" projects per-position vectors over the vocabulary.
    input_kind: str = "tokens"
151
+
152
+
153
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with an optional lower-triangular (causal) mask."""

    def __init__(self, cfg: ModelConfig):
        """Create the fused q/k/v projection, the output projection and the mask buffer."""
        super().__init__()
        assert cfg.n_embd % cfg.n_head == 0
        width, context = cfg.n_embd, cfg.block_size
        self.n_head = cfg.n_head
        self.dropout = cfg.dropout
        self.c_attn = nn.Linear(width, 3 * width)
        self.c_proj = nn.Linear(width, width)
        self.attn_drop = nn.Dropout(cfg.dropout)
        self.resid_drop = nn.Dropout(cfg.dropout)
        self.causal = cfg.causal
        lower_triangle = torch.tril(torch.ones(context, context))
        self.register_buffer("tril", lower_triangle.view(1, 1, context, context))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention to ``x`` of shape (batch, time, channels) and return the same shape."""
        b, t, c = x.shape
        head_dim = c // self.n_head

        def split_heads(z: torch.Tensor) -> torch.Tensor:
            # (b, t, c) -> (b, n_head, t, head_dim)
            return z.view(b, t, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(c, dim=2))
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if self.causal:
            # Positions above the diagonal get -inf and so zero softmax weight.
            scores = scores.masked_fill(self.tril[:, :, :t, :t] == 0, float("-inf"))
        weights = self.attn_drop(F.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(b, t, c)
        return self.resid_drop(self.c_proj(merged))
180
+
181
+
182
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and a GELU MLP, each with a residual connection."""

    def __init__(self, cfg: ModelConfig):
        """Create the two LayerNorms, the attention layer and the 4x-wide MLP."""
        super().__init__()
        width = cfg.n_embd
        hidden = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(cfg.dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the attention output, then the MLP output, to ``x``."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))
199
+
200
+
201
class TinyTransformer(nn.Module):
    """Small transformer that accepts either token ids or per-position probability vectors.

    With ``cfg.input_kind == "tokens"`` the input is embedded with
    ``token_emb`` and the output head shares that embedding's weight. With
    ``"probs"`` the input is projected by ``prob_proj`` and a time embedding
    from ``time_mlp`` is added at every position.
    """

    def __init__(self, cfg: ModelConfig):
        """Build all sub-modules and initialise their weights."""
        super().__init__()
        self.cfg = cfg
        width, vocab = cfg.n_embd, cfg.vocab_size
        self.token_emb = nn.Embedding(vocab, width)
        self.prob_proj = nn.Linear(vocab, width, bias=False)
        self.pos_emb = nn.Embedding(cfg.block_size, width)
        self.time_mlp = nn.Sequential(nn.Linear(1, width), nn.SiLU(), nn.Linear(width, width))
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, vocab, bias=False)
        if cfg.input_kind == "tokens":
            # Weight tying: the output head reuses the token embedding matrix.
            self.lm_head.weight = self.token_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Use N(0, 0.02) for Linear/Embedding weights and zeros for Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor | None = None) -> torch.Tensor:
        """Return logits over the vocabulary for every position of ``x``.

        Args:
            x: token ids in "tokens" mode, or vectors over the vocabulary in
                "probs" mode; the first two dimensions are (batch, sequence).
            t: one time value per batch element; required in "probs" mode.

        Raises:
            ValueError: if ``t`` is missing in "probs" mode or ``input_kind``
                is not recognised.
        """
        b, seq_len = x.shape[:2]
        pos = torch.arange(seq_len, device=x.device)
        kind = self.cfg.input_kind
        if kind == "tokens":
            h = self.token_emb(x)
        elif kind == "probs":
            h = self.prob_proj(x.float())
            if t is None:
                raise ValueError("LTA/prob model requires time t")
            time_features = self.time_mlp(t.float().view(b, 1))
            # Broadcast the per-sample time embedding over all positions.
            h = h + time_features.view(b, 1, -1)
        else:
            raise ValueError(f"unknown input_kind: {self.cfg.input_kind}")
        h = self.drop(h + self.pos_emb(pos).view(1, seq_len, -1))
        for layer in self.blocks:
            h = layer(h)
        return self.lm_head(self.ln_f(h))
241
+
242
+
243
@dataclass
class LTAConfig:
    """Settings for the simplex corruption used in LTA training."""

    # Concentration at t = 0.
    c_min: float = 1.0
    # Concentration at t = 1.
    c_max: float = 64.0
    # NOTE(review): not read by any code visible in this part of the file.
    endpoint_mode: str = "full_vocab_wrong"
    # "same": semantic time equals support time; "independent": drawn separately.
    t_mode: str = "same"
    # Lower clamp applied to means, Gamma parameters and samples.
    eps: float = 1e-8
250
+
251
+
252
def concentration(t: torch.Tensor, c_min: float, c_max: float) -> torch.Tensor:
    """Interpolate log-linearly from ``c_min`` at t = 0 to ``c_max`` at t = 1."""
    log_low = torch.log(torch.tensor(c_min, device=t.device))
    log_ratio = math.log(c_max / c_min)
    return torch.exp(log_low + t * log_ratio)
254
+
255
+
256
def standard_gamma(alpha: torch.Tensor) -> torch.Tensor:
    """Draw Gamma samples with parameters ``alpha`` via ``torch._standard_gamma``.

    For tensors on an MPS device the draw is done on the CPU and the result is
    moved back to the original device.
    """
    if alpha.device.type != "mps":
        return torch._standard_gamma(alpha)
    # MPS does not implement aten::_standard_gamma yet. Sampling on the CPU is
    # plenty fast for this tiny char-level experiment, while the Transformer
    # still runs on the accelerator.
    return torch._standard_gamma(alpha.cpu()).to(alpha.device)
263
+
264
+
265
def corrupt_categorical_simplex(ids: torch.Tensor, vocab_size: int, cfg: LTAConfig) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn token ids into noisy points on the probability simplex.

    For each sequence a support time is drawn uniformly from [0, 1). The
    endpoint mixes the one-hot gold token with a one-hot random token using
    the semantic time, the mean mixes that endpoint with the uniform
    distribution using the support time, and the state is a normalised Gamma
    draw with parameters ``mean * concentration(support_t)``.

    Args:
        ids: 2-D tensor of token ids, (batch, sequence).
        vocab_size: number of classes for the one-hot encodings.
        cfg: corruption settings (``t_mode``, ``c_min``, ``c_max``, ``eps``).

    Returns:
        ``(state, support_t)``: the corrupted distributions and the support
        time of each sequence.

    Raises:
        ValueError: if ``cfg.t_mode`` is neither "same" nor "independent".
    """
    b, _ = ids.shape
    device = ids.device
    eps = cfg.eps

    # The order of the random draws below is kept fixed for reproducibility.
    support_t = torch.rand(b, device=device)
    if cfg.t_mode == "same":
        semantic_t = support_t
    elif cfg.t_mode == "independent":
        semantic_t = torch.rand(b, device=device)
    else:
        raise ValueError(f"unknown t_mode: {cfg.t_mode}")

    gold = F.one_hot(ids, vocab_size).float()
    wrong_ids = torch.randint(0, vocab_size, ids.shape, device=device)
    wrong = F.one_hot(wrong_ids, vocab_size).float()
    gold_weight = semantic_t.view(b, 1, 1)
    wrong_weight = (1.0 - semantic_t).view(b, 1, 1)
    endpoint = gold_weight * gold + wrong_weight * wrong

    support = support_t.view(b, 1, 1)
    mean = ((1.0 - support) / float(vocab_size) + support * endpoint).clamp_min(eps)
    mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)

    conc = concentration(support_t, cfg.c_min, cfg.c_max).view(b, 1, 1)
    draws = standard_gamma((mean * conc).clamp_min(eps)).clamp_min(eps)
    state = draws / draws.sum(dim=-1, keepdim=True).clamp_min(eps)
    return state, support_t
290
+
291
+
292
@torch.no_grad()
def estimate_ar_loss(model: TinyTransformer, data: torch.Tensor, args, device: torch.device, eval_iters: int) -> float:
    """Average the next-token cross-entropy over ``eval_iters`` random batches.

    The model is put in eval mode for the measurement and switched to train
    mode afterwards. ``args`` must provide ``batch_size`` and ``block_size``.
    """

    def batch_loss() -> float:
        x, y = get_batch(data, batch_size=args.batch_size, block_size=args.block_size, device=device)
        logits = model(x)
        flat_logits = logits.view(-1, logits.size(-1))
        return F.cross_entropy(flat_logits, y.reshape(-1)).item()

    model.eval()
    losses = [batch_loss() for _ in range(eval_iters)]
    model.train()
    return float(sum(losses) / len(losses))
302
+
303
+
304
@torch.no_grad()
def estimate_lta_loss(model: TinyTransformer, data: torch.Tensor, args, lta_cfg: LTAConfig, device: torch.device, eval_iters: int) -> float:
    """Average the denoising cross-entropy over ``eval_iters`` random batches.

    Each batch of ids is corrupted with ``corrupt_categorical_simplex`` and the
    model is scored on recovering the original ids. The model is put in eval
    mode for the measurement and switched to train mode afterwards. ``args``
    must provide ``batch_size`` and ``block_size``.
    """

    def batch_loss() -> float:
        ids = get_block_batch(data, batch_size=args.batch_size, block_size=args.block_size, device=device)
        state, t = corrupt_categorical_simplex(ids, model.cfg.vocab_size, lta_cfg)
        logits = model(state, t)
        flat_logits = logits.view(-1, logits.size(-1))
        return F.cross_entropy(flat_logits, ids.reshape(-1)).item()

    model.eval()
    losses = [batch_loss() for _ in range(eval_iters)]
    model.train()
    return float(sum(losses) / len(losses))
315
+
316
+
317
@torch.no_grad()
def generate_ar(model: TinyTransformer, tokenizer: CharTokenizer, *, length: int, temp: float, device: torch.device, seed: str = "\n") -> str:
    """Generate ``length`` characters autoregressively, starting from ``seed``.

    Only the last ``model.cfg.block_size`` ids are fed to the model at each
    step. ``temp <= 0`` selects the argmax token; otherwise the next token is
    sampled from the softmax of ``logits / temp``. The model is left in eval
    mode.

    Returns:
        The decoded seed followed by the generated characters.
    """
    model.eval()
    ids = torch.tensor([tokenizer.encode(seed)], dtype=torch.long, device=device)
    greedy = temp <= 0
    context = model.cfg.block_size
    for _ in range(length):
        last_logits = model(ids[:, -context:])[:, -1, :]
        if greedy:
            nxt = last_logits.argmax(dim=-1, keepdim=True)
        else:
            nxt = torch.multinomial(F.softmax(last_logits / temp, dim=-1), num_samples=1)
        ids = torch.cat((ids, nxt), dim=1)
    return tokenizer.decode(ids[0])
331
+
332
+
333
@torch.no_grad()
def generate_lta(
    model: TinyTransformer,
    tokenizer: CharTokenizer,
    *,
    length: int,
    steps: int,
    c_min: float,
    c_max: float,
    temp: float,
    device: torch.device,
) -> str:
    """Decode one sequence by iteratively refining a simplex-valued state.

    Args:
        model: Model called as ``model(probs, t)`` returning per-position logits.
        tokenizer: Provides ``vocab_size`` and ``decode(ids)``.
        length: Number of positions to generate.
        steps: Number of refinement steps; step ``k`` uses ``t = (k + 1) / steps``.
        c_min: Concentration used at ``t = 0``; must be positive (it is passed
            to ``math.log``).
        c_max: Concentration used at ``t = 1``.
        temp: Divisor applied to the logits before the softmax; must be non-zero.
        device: Device on which tensors are created.

    Returns:
        The decoded argmax of an equal blend of the final state and the last
        model endpoint.

    The model is switched to eval mode while decoding and is restored to the
    mode it had on entry, so calling this from a training loop does not
    silently disable dropout for the remaining optimisation steps.
    """
    was_training = model.training
    model.eval()
    try:
        eps = 1e-8
        vocab_size = tokenizer.vocab_size
        # Initial state: normalised gamma draws with shape 1/vocab_size per entry.
        alpha = torch.full((1, length, vocab_size), 1.0 / vocab_size, device=device).clamp_min(eps)
        probs = standard_gamma(alpha).clamp_min(eps)
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
        last_endpoint = probs
        for step in range(steps):
            t_value = (step + 1) / max(steps, 1)
            t = torch.full((1,), t_value, device=device)
            logits = model(probs, t) / temp
            endpoint = F.softmax(logits, dim=-1)
            last_endpoint = endpoint
            # Decoding ties the support and semantic schedules to the same t.
            support_t = t_value
            semantic_t = t_value
            forward_endpoint = (1.0 - semantic_t) * probs + semantic_t * endpoint
            # Mix the endpoint with the uniform distribution, then renormalise.
            mean = (1.0 - support_t) / float(vocab_size) + support_t * forward_endpoint
            mean = mean.clamp_min(eps)
            mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)
            # Log-linear interpolation of the concentration from c_min to c_max.
            conc = math.exp(math.log(c_min) + support_t * math.log(c_max / c_min))
            sample = standard_gamma((mean * conc).clamp_min(eps)).clamp_min(eps)
            probs = sample / sample.sum(dim=-1, keepdim=True).clamp_min(eps)
        final = 0.5 * probs + 0.5 * last_endpoint
        ids = final.argmax(dim=-1)[0]
    finally:
        # Restore the caller's train/eval mode even if decoding fails.
        model.train(was_training)
    return tokenizer.decode(ids)
370
+
371
+
372
def text_stats(text: str) -> dict[str, float]:
    """Summarise a sample: character entropy (nats), distinct-bigram ratio, length.

    Empty input yields zero entropy and a zero bigram ratio because both
    denominators are floored at 1.
    """
    frequency = {}
    for symbol in text:
        frequency[symbol] = frequency.get(symbol, 0) + 1

    total = max(len(text), 1)
    shares = (count / total for count in frequency.values())
    char_entropy = -sum(share * math.log(share) for share in shares)

    # Adjacent character pairs; a string of length k has max(k - 1, 0) of them.
    pair_count = max(len(text) - 1, 0)
    unique_pairs = set(zip(text, text[1:]))
    distinct_ratio = len(unique_pairs) / max(pair_count, 1)

    return {
        "char_entropy": char_entropy,
        "distinct_2": distinct_ratio,
        "length": float(len(text)),
    }
382
+
383
+
384
def save_checkpoint(path: Path, model: TinyTransformer, optimizer: torch.optim.Optimizer, step: int, cfg: ModelConfig, extra: dict) -> None:
    """Write model/optimizer state, step, config and extra metadata to ``path``.

    Missing parent directories are created. ``cfg`` is stored as a plain
    dict via ``asdict``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {}
    payload["model"] = model.state_dict()
    payload["optimizer"] = optimizer.state_dict()
    payload["step"] = step
    payload["model_config"] = asdict(cfg)
    payload["extra"] = extra
    torch.save(payload, path)
396
+
397
+
398
def maybe_resume(path: str | None, model: TinyTransformer, optimizer: torch.optim.Optimizer, device: torch.device) -> int:
    """Load a checkpoint into ``model``/``optimizer`` and return its step.

    A falsy ``path`` (``None`` or an empty string) means "start fresh" and
    returns 0 without touching the model. Optimizer state is restored only
    when the checkpoint contains an ``"optimizer"`` entry; a checkpoint
    without a ``"step"`` entry resumes from 0.
    """
    if not path:
        return 0

    payload = torch.load(path, map_location=device)
    model.load_state_dict(payload["model"])

    has_optimizer_state = "optimizer" in payload
    if has_optimizer_state:
        optimizer.load_state_dict(payload["optimizer"])

    resumed_step = payload.get("step", 0)
    return int(resumed_step)
406
+
407
+
408
def train_ar(
    args,
    tokenizer: CharTokenizer,
    train_data: torch.Tensor,
    val_data: torch.Tensor,
    device: torch.device,
    *,
    rank: int,
    local_rank: int,
    is_ddp: bool,
) -> None:
    """Train the causal token-input model with next-token cross-entropy.

    Args:
        args: Parsed CLI namespace (model sizes, optimisation and logging
            settings, ``data_mode``, ``resume_path``, ``out_dir``).
        tokenizer: Supplies ``vocab_size`` and ``to_json()``.
        train_data: Token stream used for training batches.
        val_data: Token stream used for validation loss.
        device: Device the model and batches live on.
        rank: Process rank; decides which process logs, evaluates and saves.
        local_rank: Device index handed to DDP.
        is_ddp: Whether to wrap the model in ``DistributedDataParallel``.

    Outputs go to ``<args.out_dir>/ar``: ``metrics.jsonl``, one sample text
    file per evaluation, ``latest.pt`` and a final ``step_<N>.pt``.
    """
    cfg = ModelConfig(
        vocab_size=tokenizer.vocab_size,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
        causal=True,
        input_kind="tokens",
    )
    raw_model = TinyTransformer(cfg).to(device)
    model = DDP(raw_model, device_ids=[local_rank], find_unused_parameters=True) if is_ddp else raw_model
    optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    start_step = maybe_resume(args.resume_path, raw_model, optimizer, device)
    out_dir = Path(args.out_dir) / "ar"
    if is_main_process(rank):
        out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "metrics.jsonl"
    t0 = time.time()
    stream_world_size = dist.get_world_size() if is_ddp and dist.is_initialized() else 1
    # On resume, advance the stream cursor past what the completed steps consumed
    # (block_size + 1 tokens per sequence), wrapping around the training data.
    stream_cursor = (start_step * args.batch_size * (args.block_size + 1) * stream_world_size) % len(train_data)
    for step in range(start_step + 1, args.steps + 1):
        if args.data_mode == "stream":
            x, y, stream_cursor = get_stream_batch(
                train_data,
                batch_size=args.batch_size,
                block_size=args.block_size,
                device=device,
                cursor=stream_cursor,
                rank=rank,
                world_size=stream_world_size,
            )
        else:
            x, y = get_batch(train_data, batch_size=args.batch_size, block_size=args.block_size, device=device)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.reshape(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        if is_main_process(rank) and (step % args.log_interval == 0 or step == 1):
            print(f"[ar] step={step} loss={loss.item():.4f} elapsed={time.time() - t0:.1f}s", flush=True)
        if step % args.eval_interval == 0 or step == args.steps:
            # All ranks wait here while the main process evaluates and saves.
            if is_ddp:
                dist.barrier()
            if is_main_process(rank):
                val_loss = estimate_ar_loss(raw_model, val_data, args, device, args.eval_iters)
                sample = generate_ar(raw_model, tokenizer, length=args.sample_len, temp=args.ar_temp, device=device)
                # Sampling puts the model in eval mode; switch back so the
                # following training steps keep using dropout.
                raw_model.train()
                row = {"step": step, "train_loss": float(loss.item()), "val_loss": val_loss, **text_stats(sample)}
                with log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
                (out_dir / f"sample_step{step:05d}.txt").write_text(sample, encoding="utf-8")
                save_checkpoint(out_dir / "latest.pt", raw_model, optimizer, step, cfg, {"mode": "ar", "data_mode": args.data_mode, "tokenizer": tokenizer.to_json()})
            if is_ddp:
                dist.barrier()
    if is_main_process(rank):
        save_checkpoint(out_dir / f"step_{args.steps:05d}.pt", raw_model, optimizer, args.steps, cfg, {"mode": "ar", "data_mode": args.data_mode, "tokenizer": tokenizer.to_json()})
476
+
477
+
478
def train_lta(
    args,
    tokenizer: CharTokenizer,
    train_data: torch.Tensor,
    val_data: torch.Tensor,
    device: torch.device,
    *,
    rank: int,
    local_rank: int,
    is_ddp: bool,
) -> None:
    """Train the non-causal, probability-input model to reconstruct corrupted ids.

    Args:
        args: Parsed CLI namespace (model sizes, optimisation and logging
            settings, corruption and decoding settings, ``mode``, ``data_mode``,
            ``resume_path``, ``out_dir``).
        tokenizer: Supplies ``vocab_size`` and ``to_json()``.
        train_data: Token stream used for training batches.
        val_data: Token stream used for validation loss.
        device: Device the model and batches live on.
        rank: Process rank; decides which process logs, evaluates and saves.
        local_rank: Device index handed to DDP.
        is_ddp: Whether to wrap the model in ``DistributedDataParallel``.

    Outputs go to ``<args.out_dir>/fully_coupled`` when ``args.mode`` is
    ``"fully_coupled"`` and to ``<args.out_dir>/lta`` otherwise:
    ``metrics.jsonl``, one sample text file per evaluation, ``latest.pt`` and
    a final ``step_<N>.pt``.
    """
    cfg = ModelConfig(
        vocab_size=tokenizer.vocab_size,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
        causal=False,
        input_kind="probs",
    )
    lta_cfg = LTAConfig(c_min=args.c_min, c_max=args.c_max, t_mode=args.t_mode)
    raw_model = TinyTransformer(cfg).to(device)
    model = DDP(raw_model, device_ids=[local_rank], find_unused_parameters=True) if is_ddp else raw_model
    optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    start_step = maybe_resume(args.resume_path, raw_model, optimizer, device)
    out_dir = Path(args.out_dir) / ("fully_coupled" if args.mode == "fully_coupled" else "lta")
    if is_main_process(rank):
        out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "metrics.jsonl"
    t0 = time.time()
    stream_world_size = dist.get_world_size() if is_ddp and dist.is_initialized() else 1
    # On resume, advance the stream cursor past what the completed steps consumed
    # (block_size tokens per sequence), wrapping around the training data.
    stream_cursor = (start_step * args.batch_size * args.block_size * stream_world_size) % len(train_data)
    for step in range(start_step + 1, args.steps + 1):
        if args.data_mode == "stream":
            ids, stream_cursor = get_stream_block_batch(
                train_data,
                batch_size=args.batch_size,
                block_size=args.block_size,
                device=device,
                cursor=stream_cursor,
                rank=rank,
                world_size=stream_world_size,
            )
        else:
            ids = get_block_batch(train_data, batch_size=args.batch_size, block_size=args.block_size, device=device)
        # Corrupt the clean ids into a simplex state; the target stays the clean ids.
        state, t = corrupt_categorical_simplex(ids, tokenizer.vocab_size, lta_cfg)
        logits = model(state, t)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), ids.reshape(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        if is_main_process(rank) and (step % args.log_interval == 0 or step == 1):
            print(f"[lta] step={step} loss={loss.item():.4f} elapsed={time.time() - t0:.1f}s", flush=True)
        if step % args.eval_interval == 0 or step == args.steps:
            # All ranks wait here while the main process evaluates and saves.
            if is_ddp:
                dist.barrier()
            if is_main_process(rank):
                val_loss = estimate_lta_loss(raw_model, val_data, args, lta_cfg, device, args.eval_iters)
                sample = generate_lta(
                    raw_model,
                    tokenizer,
                    # The sample length is capped at the model's block size.
                    length=min(args.sample_len, args.block_size),
                    steps=args.decode_steps,
                    c_min=args.c_min,
                    c_max=args.decode_c_max,
                    temp=args.endpoint_temp,
                    device=device,
                )
                # Sampling puts the model in eval mode; switch back so the
                # following training steps keep using dropout.
                raw_model.train()
                row = {"step": step, "train_loss": float(loss.item()), "val_loss": val_loss, **text_stats(sample)}
                with log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
                (out_dir / f"sample_step{step:05d}.txt").write_text(sample, encoding="utf-8")
                save_checkpoint(out_dir / "latest.pt", raw_model, optimizer, step, cfg, {"mode": "lta", "data_mode": args.data_mode, "lta_config": asdict(lta_cfg), "tokenizer": tokenizer.to_json()})
            if is_ddp:
                dist.barrier()
    if is_main_process(rank):
        save_checkpoint(out_dir / f"step_{args.steps:05d}.pt", raw_model, optimizer, args.steps, cfg, {"mode": "lta", "data_mode": args.data_mode, "lta_config": asdict(lta_cfg), "tokenizer": tokenizer.to_json()})
557
+
558
+
559
def main() -> None:
    """Parse CLI options, prepare data and output folders, then run training.

    ``--mode`` selects which trainers run: ``ar`` and ``both`` run
    ``train_ar``; ``fully_coupled``, ``lta`` and ``both`` run ``train_lta``.
    Distributed state is always cleaned up, even if training raises.
    """
    # (flag, add_argument keyword options), registered in this exact order.
    option_specs = (
        ("--mode", {"choices": ["ar", "fully_coupled", "lta", "both"], "default": "both"}),
        ("--data_dir", {"default": "experiments/nanogpt_tinyshakespeare_char/data"}),
        ("--out_dir", {"default": "experiments/nanogpt_tinyshakespeare_char/runs/char_5k"}),
        ("--device", {"default": "auto"}),
        ("--steps", {"type": int, "default": 5000}),
        ("--data_mode", {"choices": ["random", "stream"], "default": "random"}),
        ("--batch_size", {"type": int, "default": 64}),
        ("--block_size", {"type": int, "default": 128}),
        ("--n_layer", {"type": int, "default": 4}),
        ("--n_head", {"type": int, "default": 4}),
        ("--n_embd", {"type": int, "default": 128}),
        ("--dropout", {"type": float, "default": 0.1}),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--weight_decay", {"type": float, "default": 0.1}),
        ("--grad_clip", {"type": float, "default": 1.0}),
        ("--log_interval", {"type": int, "default": 100}),
        ("--eval_interval", {"type": int, "default": 500}),
        ("--eval_iters", {"type": int, "default": 20}),
        ("--sample_len", {"type": int, "default": 512}),
        ("--ar_temp", {"type": float, "default": 0.8}),
        ("--c_min", {"type": float, "default": 1.0}),
        ("--c_max", {"type": float, "default": 64.0}),
        ("--decode_c_max", {"type": float, "default": 16.0}),
        ("--endpoint_temp", {"type": float, "default": 1.3}),
        ("--decode_steps", {"type": int, "default": 256}),
        ("--t_mode", {"choices": ["same", "independent"], "default": "same"}),
        ("--resume_path", {"default": ""}),
        ("--seed", {"type": int, "default": 1337}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    device, rank, local_rank, world_size, is_ddp = setup_distributed(args.device)
    # Each rank gets its own seed offset.
    torch.manual_seed(args.seed + rank)
    if is_main_process(rank):
        print(f"[setup] device={device} rank={rank} world_size={world_size}", flush=True)
        ensure_tinyshakespeare(Path(args.data_dir))
    # Other ranks wait until the main process has prepared the dataset.
    if is_ddp:
        dist.barrier()
    text, tokenizer, train_data, val_data = load_tinyshakespeare(Path(args.data_dir))

    out_dir = Path(args.out_dir)
    if is_main_process(rank):
        out_dir.mkdir(parents=True, exist_ok=True)
        tokenizer_dump = json.dumps(tokenizer.to_json(), ensure_ascii=False, indent=2)
        (out_dir / "tokenizer.json").write_text(tokenizer_dump, encoding="utf-8")
        args_dump = json.dumps(vars(args), ensure_ascii=False, indent=2)
        (out_dir / "args.json").write_text(args_dump, encoding="utf-8")
        print(f"[data] chars={len(text)} vocab={tokenizer.vocab_size} train={len(train_data)} val={len(val_data)}", flush=True)
    if is_ddp:
        dist.barrier()

    trainer_kwargs = {"rank": rank, "local_rank": local_rank, "is_ddp": is_ddp}
    try:
        if args.mode in {"ar", "both"}:
            train_ar(args, tokenizer, train_data, val_data, device, **trainer_kwargs)
        if args.mode in {"fully_coupled", "lta", "both"}:
            train_lta(args, tokenizer, train_data, val_data, device, **trainer_kwargs)
    finally:
        cleanup_distributed(is_ddp)
615
+
616
+
617
# Run the training CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513.log ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ *****************************************
3
+ Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ *****************************************
5
+ [rank0]:[W513 02:20:46.100148121 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
6
+ NCCL version 2.25.1+cuda12.8
7
+ [rank3]:[W513 02:20:46.141519466 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
8
+ [rank1]:[W513 02:20:46.143177080 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
9
+ [rank2]:[W513 02:20:46.172962616 ProcessGroupNCCL.cpp:4571] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
10
+ {
11
+ "device": "cuda:0",
12
+ "rank": 0,
13
+ "world_size": 4,
14
+ "samples": "owt_cached_chunks:10904",
15
+ "vocab_size": 50257,
16
+ "tokenizer_vocab_size": 50257,
17
+ "save_dir": "runs/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513",
18
+ "batch_size": 8,
19
+ "grad_accum": 16,
20
+ "effective_batch_size": 512,
21
+ "global_batch_size": 512,
22
+ "lr_schedule": "constant_warmup",
23
+ "optimizer": "muon",
24
+ "warmup_steps": 11,
25
+ "min_lr": 0.0,
26
+ "weight_decay": 0.0,
27
+ "adamw_param_groups": "nanogpt",
28
+ "adam_beta1": 0.9,
29
+ "adam_beta2": 0.95,
30
+ "adam_eps": 1e-08,
31
+ "muon_momentum": 0.95,
32
+ "muon_ns_steps": 5,
33
+ "muon_update_scale": 1.0,
34
+ "ema_decay": 0.9999,
35
+ "ema_start_step": 0,
36
+ "model_type": "ddit",
37
+ "dual_t": true,
38
+ "corrupt_t_mode": "independent",
39
+ "corrupt_min_t": null,
40
+ "corrupt_max_t": null,
41
+ "prefix_block_prob": 0.0,
42
+ "prefix_block_len": 128,
43
+ "dirichlet_endpoint_mode": "categorical_dual_t",
44
+ "dirichlet_semantic_t_mode": "same",
45
+ "dirichlet_semantic_t_value": 0.0,
46
+ "categorical_wrong_from_full_vocab": true,
47
+ "categorical_wrong_from_batch_valid_tokens": false,
48
+ "mask_mixture_original_prob": 0.0,
49
+ "mask_mixture_lowk_prob": 0.0,
50
+ "mask_mixture_lowcorrupt_prob": 0.0,
51
+ "mask_mixture_block_prob": 0.0,
52
+ "mask_mixture_all_prob": 0.0,
53
+ "mask_mixture_lowk_clean_tokens": "1,2,4,8,16,32,64",
54
+ "mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
55
+ "mask_mixture_block_tokens": "64,128",
56
+ "simplex_bridge_sampler": "dirichlet",
57
+ "logistic_normal_sigma_min": 0.18,
58
+ "logistic_normal_sigma_max": 2.2,
59
+ "logistic_normal_tau_min": 0.65,
60
+ "logistic_normal_tau_max": 1.15,
61
+ "torch_compile": false,
62
+ "compile_mode": "max-autotune",
63
+ "state_format": "prob",
64
+ "target_loss": "hard_ce",
65
+ "meanflow_weight": 0.0,
66
+ "bridge_noise_init": "logistic_normal",
67
+ "noise_sigma": -1.0,
68
+ "wrap": true,
69
+ "wrap_mode": "stream",
70
+ "wrap_record_buffer_size": 200,
71
+ "owt_cached_chunks": true,
72
+ "owt_chunk_cache_dir": "/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k_fast10k",
73
+ "owt_chunk_cache_rebuild": false,
74
+ "owt_chunk_cache_write_batch": 4096,
75
+ "owt_exact_repeat_per_chunk": 0,
76
+ "online_chunk_shuffle": false,
77
+ "online_chunk_shuffle_buffer": 10000,
78
+ "openwebtext_split": "all",
79
+ "detokenizer": "auto",
80
+ "resolved_detokenizer": null,
81
+ "num_workers": 0,
82
+ "latest_every": 25,
83
+ "resume_path": ""
84
+ }
85
+ step=10 micro_steps=160 elapsed=35.3s lr=2.000000e-03 loss_all=10.7819 acc_all=0.5062 loss_corrupt=10.7918 acc_corrupt=0.3523 corrupt_frac=0.5581 loss=10.7918 loss_recon=10.7918 loss_meanflow=0.0000 mean_model_t=0.5021 mean_corrupt_t=0.5155 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4797 init_acc_corrupt=0.4881 init_gold_top10=0.5148 init_gold_top100=0.5444
86
+ step=20 micro_steps=320 elapsed=43.2s lr=2.000000e-03 loss_all=10.5714 acc_all=0.5577 loss_corrupt=10.6476 acc_corrupt=0.3828 corrupt_frac=0.5535 loss=10.6476 loss_recon=10.6476 loss_meanflow=0.0000 mean_model_t=0.4831 mean_corrupt_t=0.5005 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4958 init_acc_corrupt=0.4692 init_gold_top10=0.4982 init_gold_top100=0.5296
87
+ step=30 micro_steps=480 elapsed=49.2s lr=2.000000e-03 loss_all=10.2797 acc_all=0.5418 loss_corrupt=10.4402 acc_corrupt=0.3646 corrupt_frac=0.5515 loss=10.4402 loss_recon=10.4402 loss_meanflow=0.0000 mean_model_t=0.5005 mean_corrupt_t=0.4971 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.5053 init_acc_corrupt=0.4602 init_gold_top10=0.4890 init_gold_top100=0.5193
88
+ step=40 micro_steps=640 elapsed=47.3s lr=2.000000e-03 loss_all=9.9466 acc_all=0.5278 loss_corrupt=10.1895 acc_corrupt=0.3629 corrupt_frac=0.5560 loss=10.1895 loss_recon=10.1895 loss_meanflow=0.0000 mean_model_t=0.4889 mean_corrupt_t=0.5037 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4913 init_acc_corrupt=0.4763 init_gold_top10=0.5031 init_gold_top100=0.5324
89
+ step=50 micro_steps=800 elapsed=48.1s lr=2.000000e-03 loss_all=9.5849 acc_all=0.5083 loss_corrupt=9.9274 acc_corrupt=0.3452 corrupt_frac=0.5483 loss=9.9274 loss_recon=9.9274 loss_meanflow=0.0000 mean_model_t=0.5005 mean_corrupt_t=0.5071 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4913 init_acc_corrupt=0.4737 init_gold_top10=0.5031 init_gold_top100=0.5327
90
+ step=60 micro_steps=960 elapsed=52.1s lr=2.000000e-03 loss_all=9.2104 acc_all=0.4910 loss_corrupt=9.6379 acc_corrupt=0.3375 corrupt_frac=0.5677 loss=9.6379 loss_recon=9.6379 loss_meanflow=0.0000 mean_model_t=0.5078 mean_corrupt_t=0.4997 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4952 init_acc_corrupt=0.4702 init_gold_top10=0.4995 init_gold_top100=0.5282
91
+ step=70 micro_steps=1120 elapsed=49.0s lr=2.000000e-03 loss_all=8.7828 acc_all=0.4820 loss_corrupt=9.3226 acc_corrupt=0.3293 corrupt_frac=0.5563 loss=9.3226 loss_recon=9.3226 loss_meanflow=0.0000 mean_model_t=0.5141 mean_corrupt_t=0.5078 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4938 init_acc_corrupt=0.4720 init_gold_top10=0.5007 init_gold_top100=0.5290
92
+ step=80 micro_steps=1280 elapsed=51.9s lr=2.000000e-03 loss_all=8.3273 acc_all=0.4771 loss_corrupt=8.9524 acc_corrupt=0.3331 corrupt_frac=0.5579 loss=8.9524 loss_recon=8.9524 loss_meanflow=0.0000 mean_model_t=0.5173 mean_corrupt_t=0.5132 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4842 init_acc_corrupt=0.4812 init_gold_top10=0.5102 init_gold_top100=0.5391
93
+ step=90 micro_steps=1440 elapsed=49.3s lr=2.000000e-03 loss_all=7.8580 acc_all=0.4804 loss_corrupt=8.5915 acc_corrupt=0.3368 corrupt_frac=0.5610 loss=8.5915 loss_recon=8.5915 loss_meanflow=0.0000 mean_model_t=0.5097 mean_corrupt_t=0.5138 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4881 init_acc_corrupt=0.4782 init_gold_top10=0.5062 init_gold_top100=0.5356
94
+ step=100 micro_steps=1600 elapsed=49.2s lr=2.000000e-03 loss_all=7.3653 acc_all=0.4879 loss_corrupt=8.2388 acc_corrupt=0.3383 corrupt_frac=0.5443 loss=8.2388 loss_recon=8.2388 loss_meanflow=0.0000 mean_model_t=0.4959 mean_corrupt_t=0.5087 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 wrong_frac=0.4917 init_acc_corrupt=0.4737 init_gold_top10=0.5032 init_gold_top100=0.5311
LTA_openwebtext_dualt/logs/elfopt_4gpu_debug_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513_trace.log ADDED
@@ -0,0 +1 @@
 
 
1
+ {"out_json": "docs/lta_samples/metrics_20260513/lta_owt_fast10k_len1024_elfopt_muon_ema_ddit768x12_4gpu_5epoch_20260513/trace_latest_ema_steps64_c48_t1p45.json", "records": 10, "step": 107}
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0010000_t1p45.log ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [watch-classic] 2026-05-21_01:37:41 infer runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000.pt -> docs/lta_samples/metrics_20260520/lm1b_classic_repro_every10k_normal_steps_state_t1p45_c1024_n1024/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000
2
+ [ckpt] runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000.pt step=10000
3
+ [decode] steps128_c1024_t1p45 generated 16/1024
4
+ [decode] steps128_c1024_t1p45 generated 32/1024
5
+ [decode] steps128_c1024_t1p45 generated 48/1024
6
+ [decode] steps128_c1024_t1p45 generated 64/1024
7
+ [decode] steps128_c1024_t1p45 generated 80/1024
8
+ [decode] steps128_c1024_t1p45 generated 96/1024
9
+ [decode] steps128_c1024_t1p45 generated 112/1024
10
+ [decode] steps128_c1024_t1p45 generated 128/1024
11
+ [decode] steps128_c1024_t1p45 generated 144/1024
12
+ [decode] steps128_c1024_t1p45 generated 160/1024
13
+ [decode] steps128_c1024_t1p45 generated 176/1024
14
+ [decode] steps128_c1024_t1p45 generated 192/1024
15
+ [decode] steps128_c1024_t1p45 generated 208/1024
16
+ [decode] steps128_c1024_t1p45 generated 224/1024
17
+ [decode] steps128_c1024_t1p45 generated 240/1024
18
+ [decode] steps128_c1024_t1p45 generated 256/1024
19
+ [decode] steps128_c1024_t1p45 generated 272/1024
20
+ [decode] steps128_c1024_t1p45 generated 288/1024
21
+ [decode] steps128_c1024_t1p45 generated 304/1024
22
+ [decode] steps128_c1024_t1p45 generated 320/1024
23
+ [decode] steps128_c1024_t1p45 generated 336/1024
24
+ [decode] steps128_c1024_t1p45 generated 352/1024
25
+ [decode] steps128_c1024_t1p45 generated 368/1024
26
+ [decode] steps128_c1024_t1p45 generated 384/1024
27
+ [decode] steps128_c1024_t1p45 generated 400/1024
28
+ [decode] steps128_c1024_t1p45 generated 416/1024
29
+ [decode] steps128_c1024_t1p45 generated 432/1024
30
+ [decode] steps128_c1024_t1p45 generated 448/1024
31
+ [decode] steps128_c1024_t1p45 generated 464/1024
32
+ [decode] steps128_c1024_t1p45 generated 480/1024
33
+ [decode] steps128_c1024_t1p45 generated 496/1024
34
+ [decode] steps128_c1024_t1p45 generated 512/1024
35
+ [decode] steps128_c1024_t1p45 generated 528/1024
36
+ [decode] steps128_c1024_t1p45 generated 544/1024
37
+ [decode] steps128_c1024_t1p45 generated 560/1024
38
+ [decode] steps128_c1024_t1p45 generated 576/1024
39
+ [decode] steps128_c1024_t1p45 generated 592/1024
40
+ [decode] steps128_c1024_t1p45 generated 608/1024
41
+ [decode] steps128_c1024_t1p45 generated 624/1024
42
+ [decode] steps128_c1024_t1p45 generated 640/1024
43
+ [decode] steps128_c1024_t1p45 generated 656/1024
44
+ [decode] steps128_c1024_t1p45 generated 672/1024
45
+ [decode] steps128_c1024_t1p45 generated 688/1024
46
+ [decode] steps128_c1024_t1p45 generated 704/1024
47
+ [decode] steps128_c1024_t1p45 generated 720/1024
48
+ [decode] steps128_c1024_t1p45 generated 736/1024
49
+ [decode] steps128_c1024_t1p45 generated 752/1024
50
+ [decode] steps128_c1024_t1p45 generated 768/1024
51
+ [decode] steps128_c1024_t1p45 generated 784/1024
52
+ [decode] steps128_c1024_t1p45 generated 800/1024
53
+ [decode] steps128_c1024_t1p45 generated 816/1024
54
+ [decode] steps128_c1024_t1p45 generated 832/1024
55
+ [decode] steps128_c1024_t1p45 generated 848/1024
56
+ [decode] steps128_c1024_t1p45 generated 864/1024
57
+ [decode] steps128_c1024_t1p45 generated 880/1024
58
+ [decode] steps128_c1024_t1p45 generated 896/1024
59
+ [decode] steps128_c1024_t1p45 generated 912/1024
60
+ [decode] steps128_c1024_t1p45 generated 928/1024
61
+ [decode] steps128_c1024_t1p45 generated 944/1024
62
+ [decode] steps128_c1024_t1p45 generated 960/1024
63
+ [decode] steps128_c1024_t1p45 generated 976/1024
64
+ [decode] steps128_c1024_t1p45 generated 992/1024
65
+ [decode] steps128_c1024_t1p45 generated 1008/1024
66
+ [decode] steps128_c1024_t1p45 generated 1024/1024
67
+ [summary] {"name": "steps128_c1024_t1p45", "step": 10000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 30.13676440721101, "stripped_genppl": 34.68526771869161, "sample_entropy": 3.465849114229341, "distinct_1": 0.03083038330078125, "distinct_2": 0.19376691683070865, "top_token_mass": 0.232330322265625, "raw_kept": 1024, "stripped_kept": 1024}
68
+ [watch-classic] 2026-05-21_01:46:53 done step_0010000
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0020000_t1p45.log ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [watch-classic] 2026-05-21_04:16:54 infer runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000.pt -> docs/lta_samples/metrics_20260520/lm1b_classic_repro_every10k_normal_steps_state_t1p45_c1024_n1024/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000
2
+ [ckpt] runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000.pt step=20000
3
+ [decode] steps128_c1024_t1p45 generated 16/1024
4
+ [decode] steps128_c1024_t1p45 generated 32/1024
5
+ [decode] steps128_c1024_t1p45 generated 48/1024
6
+ [decode] steps128_c1024_t1p45 generated 64/1024
7
+ [decode] steps128_c1024_t1p45 generated 80/1024
8
+ [decode] steps128_c1024_t1p45 generated 96/1024
9
+ [decode] steps128_c1024_t1p45 generated 112/1024
10
+ [decode] steps128_c1024_t1p45 generated 128/1024
11
+ [decode] steps128_c1024_t1p45 generated 144/1024
12
+ [decode] steps128_c1024_t1p45 generated 160/1024
13
+ [decode] steps128_c1024_t1p45 generated 176/1024
14
+ [decode] steps128_c1024_t1p45 generated 192/1024
15
+ [decode] steps128_c1024_t1p45 generated 208/1024
16
+ [decode] steps128_c1024_t1p45 generated 224/1024
17
+ [decode] steps128_c1024_t1p45 generated 240/1024
18
+ [decode] steps128_c1024_t1p45 generated 256/1024
19
+ [decode] steps128_c1024_t1p45 generated 272/1024
20
+ [decode] steps128_c1024_t1p45 generated 288/1024
21
+ [decode] steps128_c1024_t1p45 generated 304/1024
22
+ [decode] steps128_c1024_t1p45 generated 320/1024
23
+ [decode] steps128_c1024_t1p45 generated 336/1024
24
+ [decode] steps128_c1024_t1p45 generated 352/1024
25
+ [decode] steps128_c1024_t1p45 generated 368/1024
26
+ [decode] steps128_c1024_t1p45 generated 384/1024
27
+ [decode] steps128_c1024_t1p45 generated 400/1024
28
+ [decode] steps128_c1024_t1p45 generated 416/1024
29
+ [decode] steps128_c1024_t1p45 generated 432/1024
30
+ [decode] steps128_c1024_t1p45 generated 448/1024
31
+ [decode] steps128_c1024_t1p45 generated 464/1024
32
+ [decode] steps128_c1024_t1p45 generated 480/1024
33
+ [decode] steps128_c1024_t1p45 generated 496/1024
34
+ [decode] steps128_c1024_t1p45 generated 512/1024
35
+ [decode] steps128_c1024_t1p45 generated 528/1024
36
+ [decode] steps128_c1024_t1p45 generated 544/1024
37
+ [decode] steps128_c1024_t1p45 generated 560/1024
38
+ [decode] steps128_c1024_t1p45 generated 576/1024
39
+ [decode] steps128_c1024_t1p45 generated 592/1024
40
+ [decode] steps128_c1024_t1p45 generated 608/1024
41
+ [decode] steps128_c1024_t1p45 generated 624/1024
42
+ [decode] steps128_c1024_t1p45 generated 640/1024
43
+ [decode] steps128_c1024_t1p45 generated 656/1024
44
+ [decode] steps128_c1024_t1p45 generated 672/1024
45
+ [decode] steps128_c1024_t1p45 generated 688/1024
46
+ [decode] steps128_c1024_t1p45 generated 704/1024
47
+ [decode] steps128_c1024_t1p45 generated 720/1024
48
+ [decode] steps128_c1024_t1p45 generated 736/1024
49
+ [decode] steps128_c1024_t1p45 generated 752/1024
50
+ [decode] steps128_c1024_t1p45 generated 768/1024
51
+ [decode] steps128_c1024_t1p45 generated 784/1024
52
+ [decode] steps128_c1024_t1p45 generated 800/1024
53
+ [decode] steps128_c1024_t1p45 generated 816/1024
54
+ [decode] steps128_c1024_t1p45 generated 832/1024
55
+ [decode] steps128_c1024_t1p45 generated 848/1024
56
+ [decode] steps128_c1024_t1p45 generated 864/1024
57
+ [decode] steps128_c1024_t1p45 generated 880/1024
58
+ [decode] steps128_c1024_t1p45 generated 896/1024
59
+ [decode] steps128_c1024_t1p45 generated 912/1024
60
+ [decode] steps128_c1024_t1p45 generated 928/1024
61
+ [decode] steps128_c1024_t1p45 generated 944/1024
62
+ [decode] steps128_c1024_t1p45 generated 960/1024
63
+ [decode] steps128_c1024_t1p45 generated 976/1024
64
+ [decode] steps128_c1024_t1p45 generated 992/1024
65
+ [decode] steps128_c1024_t1p45 generated 1008/1024
66
+ [decode] steps128_c1024_t1p45 generated 1024/1024
67
+ [summary] {"name": "steps128_c1024_t1p45", "step": 20000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 46.858096679153796, "stripped_genppl": 50.24137870832906, "sample_entropy": 3.638899175986386, "distinct_1": 0.0443572998046875, "distinct_2": 0.28157295767716534, "top_token_mass": 0.1688232421875, "raw_kept": 1024, "stripped_kept": 1024}
68
+ [watch-classic] 2026-05-21_04:26:05 done step_0020000
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/infer_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_step_0030000_t1p45.log ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [watch-classic] 2026-05-21_06:55:06 infer runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000.pt -> docs/lta_samples/metrics_20260520/lm1b_classic_repro_every10k_normal_steps_state_t1p45_c1024_n1024/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000
2
+ [ckpt] runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000.pt step=30000
3
+ [decode] steps128_c1024_t1p45 generated 16/1024
4
+ [decode] steps128_c1024_t1p45 generated 32/1024
5
+ [decode] steps128_c1024_t1p45 generated 48/1024
6
+ [decode] steps128_c1024_t1p45 generated 64/1024
7
+ [decode] steps128_c1024_t1p45 generated 80/1024
8
+ [decode] steps128_c1024_t1p45 generated 96/1024
9
+ [decode] steps128_c1024_t1p45 generated 112/1024
10
+ [decode] steps128_c1024_t1p45 generated 128/1024
11
+ [decode] steps128_c1024_t1p45 generated 144/1024
12
+ [decode] steps128_c1024_t1p45 generated 160/1024
13
+ [decode] steps128_c1024_t1p45 generated 176/1024
14
+ [decode] steps128_c1024_t1p45 generated 192/1024
15
+ [decode] steps128_c1024_t1p45 generated 208/1024
16
+ [decode] steps128_c1024_t1p45 generated 224/1024
17
+ [decode] steps128_c1024_t1p45 generated 240/1024
18
+ [decode] steps128_c1024_t1p45 generated 256/1024
19
+ [decode] steps128_c1024_t1p45 generated 272/1024
20
+ [decode] steps128_c1024_t1p45 generated 288/1024
21
+ [decode] steps128_c1024_t1p45 generated 304/1024
22
+ [decode] steps128_c1024_t1p45 generated 320/1024
23
+ [decode] steps128_c1024_t1p45 generated 336/1024
24
+ [decode] steps128_c1024_t1p45 generated 352/1024
25
+ [decode] steps128_c1024_t1p45 generated 368/1024
26
+ [decode] steps128_c1024_t1p45 generated 384/1024
27
+ [decode] steps128_c1024_t1p45 generated 400/1024
28
+ [decode] steps128_c1024_t1p45 generated 416/1024
29
+ [decode] steps128_c1024_t1p45 generated 432/1024
30
+ [decode] steps128_c1024_t1p45 generated 448/1024
31
+ [decode] steps128_c1024_t1p45 generated 464/1024
32
+ [decode] steps128_c1024_t1p45 generated 480/1024
33
+ [decode] steps128_c1024_t1p45 generated 496/1024
34
+ [decode] steps128_c1024_t1p45 generated 512/1024
35
+ [decode] steps128_c1024_t1p45 generated 528/1024
36
+ [decode] steps128_c1024_t1p45 generated 544/1024
37
+ [decode] steps128_c1024_t1p45 generated 560/1024
38
+ [decode] steps128_c1024_t1p45 generated 576/1024
39
+ [decode] steps128_c1024_t1p45 generated 592/1024
40
+ [decode] steps128_c1024_t1p45 generated 608/1024
41
+ [decode] steps128_c1024_t1p45 generated 624/1024
42
+ [decode] steps128_c1024_t1p45 generated 640/1024
43
+ [decode] steps128_c1024_t1p45 generated 656/1024
44
+ [decode] steps128_c1024_t1p45 generated 672/1024
45
+ [decode] steps128_c1024_t1p45 generated 688/1024
46
+ [decode] steps128_c1024_t1p45 generated 704/1024
47
+ [decode] steps128_c1024_t1p45 generated 720/1024
48
+ [decode] steps128_c1024_t1p45 generated 736/1024
49
+ [decode] steps128_c1024_t1p45 generated 752/1024
50
+ [decode] steps128_c1024_t1p45 generated 768/1024
51
+ [decode] steps128_c1024_t1p45 generated 784/1024
52
+ [decode] steps128_c1024_t1p45 generated 800/1024
53
+ [decode] steps128_c1024_t1p45 generated 816/1024
54
+ [decode] steps128_c1024_t1p45 generated 832/1024
55
+ [decode] steps128_c1024_t1p45 generated 848/1024
56
+ [decode] steps128_c1024_t1p45 generated 864/1024
57
+ [decode] steps128_c1024_t1p45 generated 880/1024
58
+ [decode] steps128_c1024_t1p45 generated 896/1024
59
+ [decode] steps128_c1024_t1p45 generated 912/1024
60
+ [decode] steps128_c1024_t1p45 generated 928/1024
61
+ [decode] steps128_c1024_t1p45 generated 944/1024
62
+ [decode] steps128_c1024_t1p45 generated 960/1024
63
+ [decode] steps128_c1024_t1p45 generated 976/1024
64
+ [decode] steps128_c1024_t1p45 generated 992/1024
65
+ [decode] steps128_c1024_t1p45 generated 1008/1024
66
+ [decode] steps128_c1024_t1p45 generated 1024/1024
67
+ [summary] {"name": "steps128_c1024_t1p45", "step": 30000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 42.11245193528327, "stripped_genppl": 57.07543476214156, "sample_entropy": 4.026958025870078, "distinct_1": 0.04402923583984375, "distinct_2": 0.3001584030511811, "top_token_mass": 0.0693817138671875, "raw_kept": 1024, "stripped_kept": 1024}
68
+ [watch-classic] 2026-05-21_07:04:13 done step_0030000
LTA_openwebtext_dualt/logs/lm1b_classic_repro_infer_watch/processed_lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005_steps128_c1024_t1p45_n1024.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0010000.pt
2
+ runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0020000.pt
3
+ runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0030000.pt
4
+ runs/lta_lm1b_classic_c1024_fullvocab_len128_repro_save10k_gbs512_4gpu_1m_20260520_231005/step_0040000.pt
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.bat ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @REM Copyright (c) 2020-202x The virtualenv developers
2
+ @REM
3
+ @REM Permission is hereby granted, free of charge, to any person obtaining
4
+ @REM a copy of this software and associated documentation files (the
5
+ @REM "Software"), to deal in the Software without restriction, including
6
+ @REM without limitation the rights to use, copy, modify, merge, publish,
7
+ @REM distribute, sublicense, and/or sell copies of the Software, and to
8
+ @REM permit persons to whom the Software is furnished to do so, subject to
9
+ @REM the following conditions:
10
+ @REM
11
+ @REM The above copyright notice and this permission notice shall be
12
+ @REM included in all copies or substantial portions of the Software.
13
+ @REM
14
+ @REM THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
+ @REM EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
+ @REM MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
+ @REM NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18
+ @REM LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19
+ @REM OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20
+ @REM WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
+
22
+ @REM This file is UTF-8 encoded, so we need to update the current code page while executing it
23
+ @for /f "tokens=2 delims=:." %%a in ('"%SystemRoot%\System32\chcp.com"') do @set _OLD_CODEPAGE=%%a
24
+
25
+ @if defined _OLD_CODEPAGE (
26
+ "%SystemRoot%\System32\chcp.com" 65001 > nul
27
+ )
28
+
29
+ @for %%i in ("/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv") do @set "VIRTUAL_ENV=%%~fi"
30
+
31
+ @set "VIRTUAL_ENV_PROMPT="
32
+ @if NOT DEFINED VIRTUAL_ENV_PROMPT (
33
+ @for %%d in ("%VIRTUAL_ENV%") do @set "VIRTUAL_ENV_PROMPT=%%~nxd"
34
+ )
35
+
36
+ @if defined _OLD_VIRTUAL_PROMPT (
37
+ @set "PROMPT=%_OLD_VIRTUAL_PROMPT%"
38
+ ) else (
39
+ @if not defined PROMPT (
40
+ @set "PROMPT=$P$G"
41
+ )
42
+ @if not defined VIRTUAL_ENV_DISABLE_PROMPT (
43
+ @set "_OLD_VIRTUAL_PROMPT=%PROMPT%"
44
+ )
45
+ )
46
+ @if not defined VIRTUAL_ENV_DISABLE_PROMPT (
47
+ @set "PROMPT=(%VIRTUAL_ENV_PROMPT%) %PROMPT%"
48
+ )
49
+
50
+ @REM Don't use () to avoid problems with them in %PATH%
51
+ @if defined _OLD_VIRTUAL_PYTHONHOME @goto ENDIFVHOME
52
+ @set "_OLD_VIRTUAL_PYTHONHOME=%PYTHONHOME%"
53
+ :ENDIFVHOME
54
+
55
+ @set PYTHONHOME=
56
+
57
+ @REM if defined _OLD_VIRTUAL_PATH (
58
+ @if not defined _OLD_VIRTUAL_PATH @goto ENDIFVPATH1
59
+ @set "PATH=%_OLD_VIRTUAL_PATH%"
60
+ :ENDIFVPATH1
61
+ @REM ) else (
62
+ @if defined _OLD_VIRTUAL_PATH @goto ENDIFVPATH2
63
+ @set "_OLD_VIRTUAL_PATH=%PATH%"
64
+ :ENDIFVPATH2
65
+
66
+ @set "PATH=%VIRTUAL_ENV%\bin;%PATH%"
67
+
68
+ @if defined _OLD_CODEPAGE (
69
+ "%SystemRoot%\System32\chcp.com" %_OLD_CODEPAGE% > nul
70
+ @set _OLD_CODEPAGE=
71
+ )
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/activate.fish ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-202x The virtualenv developers
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining
4
+ # a copy of this software and associated documentation files (the
5
+ # "Software"), to deal in the Software without restriction, including
6
+ # without limitation the rights to use, copy, modify, merge, publish,
7
+ # distribute, sublicense, and/or sell copies of the Software, and to
8
+ # permit persons to whom the Software is furnished to do so, subject to
9
+ # the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be
12
+ # included in all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
+ # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18
+ # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19
+ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20
+ # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
+
22
+ # This file must be used using `source bin/activate.fish` *within a running fish ( http://fishshell.com ) session*.
23
+ # Do not run it directly.
24
+
25
+ function _bashify_path -d "Converts a fish path to something bash can recognize"
26
+ set fishy_path $argv
27
+ set bashy_path $fishy_path[1]
28
+ for path_part in $fishy_path[2..-1]
29
+ set bashy_path "$bashy_path:$path_part"
30
+ end
31
+ echo $bashy_path
32
+ end
33
+
34
+ function _fishify_path -d "Converts a bash path to something fish can recognize"
35
+ echo $argv | tr ':' '\n'
36
+ end
37
+
38
+ function deactivate -d 'Exit virtualenv mode and return to the normal environment.'
39
+ # reset old environment variables
40
+ if test -n "$_OLD_VIRTUAL_PATH"
41
+ # https://github.com/fish-shell/fish-shell/issues/436 altered PATH handling
42
+ if test (string sub -s 1 -l 1 $FISH_VERSION) -lt 3
43
+ set -gx PATH (_fishify_path "$_OLD_VIRTUAL_PATH")
44
+ else
45
+ set -gx PATH $_OLD_VIRTUAL_PATH
46
+ end
47
+ set -e _OLD_VIRTUAL_PATH
48
+ end
49
+
50
+ if test -n "$_OLD_VIRTUAL_PYTHONHOME"
51
+ set -gx PYTHONHOME "$_OLD_VIRTUAL_PYTHONHOME"
52
+ set -e _OLD_VIRTUAL_PYTHONHOME
53
+ end
54
+
55
+ if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
56
+ and functions -q _old_fish_prompt
57
+ # Set an empty local `$fish_function_path` to allow the removal of `fish_prompt` using `functions -e`.
58
+ set -l fish_function_path
59
+
60
+ # Erase virtualenv's `fish_prompt` and restore the original.
61
+ functions -e fish_prompt
62
+ functions -c _old_fish_prompt fish_prompt
63
+ functions -e _old_fish_prompt
64
+ set -e _OLD_FISH_PROMPT_OVERRIDE
65
+ end
66
+
67
+ set -e VIRTUAL_ENV
68
+ set -e VIRTUAL_ENV_PROMPT
69
+
70
+ if test "$argv[1]" != 'nondestructive'
71
+ # Self-destruct!
72
+ functions -e pydoc
73
+ functions -e deactivate
74
+ functions -e _bashify_path
75
+ functions -e _fishify_path
76
+ end
77
+ end
78
+
79
+ # Unset irrelevant variables.
80
+ deactivate nondestructive
81
+
82
+ set -gx VIRTUAL_ENV '/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv'
83
+
84
+ # https://github.com/fish-shell/fish-shell/issues/436 altered PATH handling
85
+ if test (string sub -s 1 -l 1 $FISH_VERSION) -lt 3
86
+ set -gx _OLD_VIRTUAL_PATH (_bashify_path $PATH)
87
+ else
88
+ set -gx _OLD_VIRTUAL_PATH $PATH
89
+ end
90
+ set -gx PATH "$VIRTUAL_ENV"'/bin' $PATH
91
+
92
+ # Prompt override provided?
93
+ # If not, just use the environment name.
94
+ if test -n ''
95
+ set -gx VIRTUAL_ENV_PROMPT ''
96
+ else
97
+ set -gx VIRTUAL_ENV_PROMPT (basename "$VIRTUAL_ENV")
98
+ end
99
+
100
+ # Unset `$PYTHONHOME` if set.
101
+ if set -q PYTHONHOME
102
+ set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
103
+ set -e PYTHONHOME
104
+ end
105
+
106
+ function pydoc
107
+ python -m pydoc $argv
108
+ end
109
+
110
+ if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
111
+ # Copy the current `fish_prompt` function as `_old_fish_prompt`.
112
+ functions -c fish_prompt _old_fish_prompt
113
+
114
+ function fish_prompt
115
+ # Run the user's prompt first; it might depend on (pipe)status.
116
+ set -l prompt (_old_fish_prompt)
117
+
118
+ printf '(%s) ' $VIRTUAL_ENV_PROMPT
119
+
120
+ string join -- \n $prompt # handle multi-line prompts
121
+ end
122
+
123
+ set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
124
+ end
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/f2py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
#!/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/python3
# -*- coding: utf-8 -*-
# Console entry point that forwards to numpy's f2py `main`.
import sys

from numpy.f2py.f2py2e import main

if __name__ == "__main__":
    # Drop a launcher suffix from the program name before handing control to f2py.
    # Only the first matching suffix is removed, mirroring an if/elif chain.
    for launcher_suffix in ("-script.pyw", ".exe"):
        if sys.argv[0].endswith(launcher_suffix):
            sys.argv[0] = sys.argv[0][: -len(launcher_suffix)]
            break
    sys.exit(main())
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/bin/pydoc.bat ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @REM Copyright (c) 2020-202x The virtualenv developers
2
+ @REM
3
+ @REM Permission is hereby granted, free of charge, to any person obtaining
4
+ @REM a copy of this software and associated documentation files (the
5
+ @REM "Software"), to deal in the Software without restriction, including
6
+ @REM without limitation the rights to use, copy, modify, merge, publish,
7
+ @REM distribute, sublicense, and/or sell copies of the Software, and to
8
+ @REM permit persons to whom the Software is furnished to do so, subject to
9
+ @REM the following conditions:
10
+ @REM
11
+ @REM The above copyright notice and this permission notice shall be
12
+ @REM included in all copies or substantial portions of the Software.
13
+ @REM
14
+ @REM THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
+ @REM EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
+ @REM MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
+ @REM NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18
+ @REM LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19
+ @REM OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20
+ @REM WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
+
22
+ python.exe -m pydoc %*
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/audio_utils.py ADDED
@@ -0,0 +1,1254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
16
+ and remove unnecessary dependencies.
17
+ """
18
+
19
+ import base64
20
+ import importlib
21
+ import io
22
+ import os
23
+ import warnings
24
+ from collections.abc import Sequence
25
+ from io import BytesIO
26
+ from typing import TYPE_CHECKING, Any, Union
27
+
28
+ import httpx
29
+ import numpy as np
30
+ from packaging import version
31
+
32
+ from .utils import (
33
+ is_librosa_available,
34
+ is_numpy_array,
35
+ is_soundfile_available,
36
+ is_torch_tensor,
37
+ is_torchcodec_available,
38
+ requires_backends,
39
+ )
40
+ from .utils.generic import retry
41
+
42
+
43
+ if TYPE_CHECKING:
44
+ import torch
45
+
46
+ if is_soundfile_available():
47
+ import soundfile as sf
48
+
49
+ if is_librosa_available():
50
+ import librosa
51
+
52
+ # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
53
+ import soxr
54
+
55
+ if is_torchcodec_available():
56
+ TORCHCODEC_VERSION = version.parse(importlib.metadata.version("torchcodec"))
57
+
58
+ AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]
59
+
60
+
61
@retry(exceptions=(httpx.HTTPError,))
def _fetch_audio_bytes(url: str, timeout: float | None = 10.0) -> bytes:
    """
    Download the raw payload served at `url`, with automatic retry and exponential backoff.

    Args:
        url (`str`):
            The address to download from. Redirects are followed.
        timeout (`float`, *optional*, defaults to 10.0):
            The timeout passed to the HTTP request.

    Returns:
        `bytes`: The body of the response.
    """
    reply = httpx.get(url, follow_redirects=True, timeout=timeout)
    # Turn an error status into an exception so that the `retry` decorator can see it.
    reply.raise_for_status()
    return reply.content
67
+
68
+
69
+ def load_audio(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
70
+ """
71
+ Loads `audio` to an np.ndarray object.
72
+
73
+ Args:
74
+ audio (`str` or `np.ndarray`):
75
+ The audio to be loaded to the numpy array format.
76
+ sampling_rate (`int`, *optional*, defaults to 16000):
77
+ The sampling rate to be used when loading the audio. It should be same as the
78
+ sampling rate the model you will be using further was trained with.
79
+ timeout (`float`, *optional*):
80
+ The timeout value in seconds for the URL request.
81
+
82
+ Returns:
83
+ `np.ndarray`: A numpy array representing the audio.
84
+ """
85
+ if isinstance(audio, str):
86
+ # Try to load with `torchcodec` but do not enforce users to install it. If not found
87
+ # fallback to `librosa`. If using an audio-only model, most probably `torchcodec` won't be
88
+ # needed. Do not raise any errors if not installed or versions do not match
89
+ if is_torchcodec_available() and version.parse("0.3.0") <= TORCHCODEC_VERSION:
90
+ audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate, timeout=timeout)
91
+ elif audio.rsplit("?", 1)[0].lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv")):
92
+ raise RuntimeError(
93
+ f"The audio source appears to be a video file ('{audio.split('/')[-1]}'). "
94
+ "librosa cannot decode video containers. "
95
+ "Install torchcodec>=0.3.0 (`pip install torchcodec`) to load audio from video files."
96
+ )
97
+ else:
98
+ audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
99
+ elif not isinstance(audio, np.ndarray):
100
+ raise TypeError(
101
+ "Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
102
+ )
103
+ return audio
104
+
105
+
106
def load_audio_torchcodec(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
    """
    Decode `audio` into a numpy array with `torchcodec`.

    Args:
        audio (`str` or `np.ndarray`):
            The audio to be loaded to the numpy array format.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate to be used when loading the audio. It should be same as the
            sampling rate the model you will be using further was trained with.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `np.ndarray`: A numpy array representing the audio.
    """
    # Imported lazily so that torchcodec compatibility issues do not crash the whole library.
    requires_backends(load_audio_torchcodec, ["torchcodec"])
    from torchcodec.decoders import AudioDecoder

    # URLs are downloaded here so they benefit from our retry logic; torchcodec does not
    # surface ffmpeg's network retry options.
    source = audio
    if isinstance(source, str) and source.startswith(("http://", "https://")):
        source = _fetch_audio_bytes(source, timeout=timeout)

    # A single channel is what most models expect, and is also the default in librosa.
    decoder = AudioDecoder(source, sample_rate=sampling_rate, num_channels=1)
    samples = decoder.get_all_samples()
    # NOTE: feature extractors don't accept torch tensors, hence the conversion to numpy.
    return samples.data[0].numpy()
134
+
135
+
136
def load_audio_librosa(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
    """
    Loads `audio` to an np.ndarray object using `librosa`.

    Args:
        audio (`str` or `np.ndarray`):
            A URL, a local file path, or an already decoded waveform. Arrays are returned unchanged.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate to be used when loading the audio. It should be same as the
            sampling rate the model you will be using further was trained with.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `np.ndarray`: A numpy array representing the audio.
    """
    requires_backends(load_audio_librosa, ["librosa"])

    # The signature accepts arrays, but the string checks below would fail on them with an
    # AttributeError; pass them through unchanged.
    if isinstance(audio, np.ndarray):
        return audio

    # Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
    if audio.startswith(("http://", "https://")):
        return librosa.load(BytesIO(_fetch_audio_bytes(audio, timeout=timeout)), sr=sampling_rate)[0]

    if os.path.isfile(audio):
        return librosa.load(audio, sr=sampling_rate)[0]

    # NOTE(review): a string that is neither a URL nor an existing file is returned as-is (existing
    # behaviour, kept for compatibility). Confirm whether callers rely on it before raising here.
    return audio
160
+
161
+
162
def load_audio_as(
    audio: str,
    return_format: str,
    timeout: int | None = None,
    force_mono: bool = False,
    sampling_rate: int | None = None,
) -> str | dict[str, Any] | io.BytesIO | None:
    """
    Load audio from either a local file path or URL and return in specified format.

    Args:
        audio (`str`): Either a local file path or a URL to an audio file
        return_format (`str`): Format to return the audio in:
            - "base64": Base64 encoded string
            - "dict": Dictionary with data and format
            - "buffer": BytesIO object
        timeout (`int`, *optional*): Timeout for URL requests in seconds
        force_mono (`bool`): Whether to convert stereo audio to mono
        sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate.

    Returns:
        `Union[str, Dict[str, Any], io.BytesIO, None]`:
            - `str`: Base64 encoded audio data (if return_format="base64")
            - `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict")
            - `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer")

    Raises:
        ValueError: If `return_format` is not supported, or if anything goes wrong while fetching, decoding or
            re-encoding the audio. In the latter case the original exception is attached as `__cause__`.
    """
    requires_backends(load_audio_as, ["librosa"])

    if return_format not in ["base64", "dict", "buffer"]:
        raise ValueError(f"Invalid return_format: {return_format}. Must be 'base64', 'dict', or 'buffer'")

    try:
        # Load audio bytes from URL or file
        if audio.startswith(("http://", "https://")):
            audio_bytes = _fetch_audio_bytes(audio, timeout=timeout)
        elif os.path.isfile(audio):
            with open(audio, "rb") as audio_file:
                audio_bytes = audio_file.read()
        else:
            raise ValueError(f"File not found: {audio}")

        # Process audio data
        with io.BytesIO(audio_bytes) as audio_file:
            with sf.SoundFile(audio_file) as f:
                audio_array = f.read(dtype="float32")
                original_sr = f.samplerate
                audio_format = f.format

        if sampling_rate is not None and sampling_rate != original_sr:
            # Resample audio to target sampling rate
            audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ")
        else:
            # No resampling requested (or needed): re-encode at the file's own rate.
            sampling_rate = original_sr

        # Convert to mono if needed
        if force_mono and audio_array.ndim != 1:
            audio_array = audio_array.mean(axis=1)

        # Re-encode in the container format the audio was read from.
        buffer = io.BytesIO()
        sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper())
        buffer.seek(0)

        if return_format == "buffer":
            return buffer
        elif return_format == "base64":
            return base64.b64encode(buffer.read()).decode("utf-8")
        elif return_format == "dict":
            return {
                "data": base64.b64encode(buffer.read()).decode("utf-8"),
                "format": audio_format.lower(),
            }

    except Exception as e:
        # Keep the public exception type, but chain the cause so the original traceback is not lost.
        raise ValueError(f"Error loading audio: {e}") from e
236
+
237
+
238
def conv1d_output_length(module: "torch.nn.Conv1d", input_length: int) -> int:
    """
    Return the length produced by `module` for an input of length `input_length`, following the formula in
    torch's documentation: https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

    Only the first entry of `padding`, `dilation`, `kernel_size` and `stride` is read.
    """
    padding = module.padding[0]
    dilation = module.dilation[0]
    kernel_size = module.kernel_size[0]
    stride = module.stride[0]

    # Input positions left once padding is added and the (dilated) kernel extent is removed.
    usable_length = input_length + 2 * padding - dilation * (kernel_size - 1) - 1
    return int(usable_length / stride + 1)
248
+
249
+
250
def is_valid_audio(audio):
    """Tell whether `audio` is a numpy array or a torch tensor."""
    numpy_check = is_numpy_array(audio)
    # Only consult the torch check when the numpy one did not succeed.
    return numpy_check if numpy_check else is_torch_tensor(audio)
252
+
253
+
254
def is_valid_list_of_audio(audio):
    """
    Tell whether `audio` is a non-empty collection in which every item is a valid audio.

    Always returns a real `bool`; an empty or `None` input yields `False` rather than the input itself.
    """
    return bool(audio) and all(is_valid_audio(audio_i) for audio_i in audio)
256
+
257
+
258
def make_list_of_audio(
    audio: list[AudioInput] | AudioInput,
) -> AudioInput:
    """
    Normalize the input to a collection of audio.

    Args:
        audio (`Union[list[AudioInput], AudioInput]`):
            Either a single audio or a list/tuple of audio.

    Returns:
        The input itself when it is already a valid list or tuple of audio, otherwise a one-element list
        wrapping the single audio.

    Raises:
        ValueError: If the input is neither a valid audio nor a valid list/tuple of audio.
    """
    # A list/tuple made only of valid audio is returned untouched.
    already_batched = isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio)
    if already_batched:
        return audio

    if not is_valid_audio(audio):
        raise ValueError("Invalid input type. Must be a single audio or a list of audio")

    # A single audio is wrapped into a list.
    return [audio]
278
+
279
+
280
+ def hertz_to_mel(freq: float | np.ndarray, mel_scale: str = "htk") -> float | np.ndarray:
281
+ """
282
+ Convert frequency from hertz to mels.
283
+
284
+ Args:
285
+ freq (`float` or `np.ndarray`):
286
+ The frequency, or multiple frequencies, in hertz (Hz).
287
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
288
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
289
+
290
+ Returns:
291
+ `float` or `np.ndarray`: The frequencies on the mel scale.
292
+ """
293
+
294
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
295
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
296
+
297
+ if mel_scale == "htk":
298
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
299
+ elif mel_scale == "kaldi":
300
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
301
+
302
+ min_log_hertz = 1000.0
303
+ min_log_mel = 15.0
304
+ logstep = 27.0 / np.log(6.4)
305
+ mels = 3.0 * freq / 200.0
306
+
307
+ if isinstance(freq, np.ndarray):
308
+ log_region = freq >= min_log_hertz
309
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
310
+ elif freq >= min_log_hertz:
311
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
312
+
313
+ return mels
314
+
315
+
316
+ def mel_to_hertz(mels: float | np.ndarray, mel_scale: str = "htk") -> float | np.ndarray:
317
+ """
318
+ Convert frequency from mels to hertz.
319
+
320
+ Args:
321
+ mels (`float` or `np.ndarray`):
322
+ The frequency, or multiple frequencies, in mels.
323
+ mel_scale (`str`, *optional*, `"htk"`):
324
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
325
+
326
+ Returns:
327
+ `float` or `np.ndarray`: The frequencies in hertz.
328
+ """
329
+
330
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
331
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
332
+
333
+ if mel_scale == "htk":
334
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
335
+ elif mel_scale == "kaldi":
336
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
337
+
338
+ min_log_hertz = 1000.0
339
+ min_log_mel = 15.0
340
+ logstep = np.log(6.4) / 27.0
341
+ freq = 200.0 * mels / 3.0
342
+
343
+ if isinstance(mels, np.ndarray):
344
+ log_region = mels >= min_log_mel
345
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
346
+ elif mels >= min_log_mel:
347
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
348
+
349
+ return freq
350
+
351
+
352
+ def hertz_to_octave(freq: float | np.ndarray, tuning: float = 0.0, bins_per_octave: int = 12):
353
+ """
354
+ Convert frequency from hertz to fractional octave numbers.
355
+ Adapted from *librosa*.
356
+
357
+ Args:
358
+ freq (`float` or `np.ndarray`):
359
+ The frequency, or multiple frequencies, in hertz (Hz).
360
+ tuning (`float`, defaults to `0.`):
361
+ Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
362
+ bins_per_octave (`int`, defaults to `12`):
363
+ Number of bins per octave.
364
+
365
+ Returns:
366
+ `float` or `np.ndarray`: The frequencies on the octave scale.
367
+ """
368
+ stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
369
+ octave = np.log2(freq / (float(stuttgart_pitch) / 16))
370
+ return octave
371
+
372
+
373
+ def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
374
+ """
375
+ Creates a triangular filter bank.
376
+
377
+ Adapted from *torchaudio* and *librosa*.
378
+
379
+ Args:
380
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
381
+ Discrete frequencies of the FFT bins in Hz.
382
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
383
+ Center frequencies of the triangular filters to create, in Hz.
384
+
385
+ Returns:
386
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
387
+ """
388
+ filter_diff = np.diff(filter_freqs)
389
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
390
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
391
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
392
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
393
+
394
+
395
def chroma_filter_bank(
    num_frequency_bins: int,
    num_chroma: int,
    sampling_rate: int,
    tuning: float = 0.0,
    power: float | None = 2.0,
    weighting_parameters: tuple[float, float] | None = (5.0, 2.0),
    start_at_c_chroma: bool = True,
) -> np.ndarray:
    """
    Creates a chroma filter bank, i.e. a linear transformation to project spectrogram bins onto chroma bins.

    Adapted from *librosa*.

    Args:
        num_frequency_bins (`int`):
            Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
        num_chroma (`int`):
            Number of chroma bins (i.e. pitch classes).
        sampling_rate (`int`):
            Sample rate of the audio waveform.
        tuning (`float`, *optional*, defaults to 0.0):
            Tuning deviation from A440 in fractions of a chroma bin.
        power (`float`, *optional*, defaults to 2.0):
            If 2.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm.
            If `None`, no normalization is applied.
        weighting_parameters (`tuple[float, float]`, *optional*, defaults to `(5., 2.)`):
            If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center
            and the second element being the Gaussian half-width.
        start_at_c_chroma (`bool`, *optional*, defaults to `True`):
            If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.

    Returns:
        `np.ndarray` of shape `(num_chroma, 1 + num_frequency_bins // 2)`
    """
    # Get the FFT bins, not counting the DC component
    frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]

    freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)

    # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
    # (so chroma is 50% rotated from bin 1, and bin width is broad)
    freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))

    # Width of each bin in chroma units, floored at 1; the last bin has no successor and gets width 1.
    bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))

    # Signed distance from every FFT bin to every chroma bin; rows are chroma bins after the transpose.
    chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T

    num_chroma2 = np.round(float(num_chroma) / 2)

    # Project into range -num_chroma/2 .. num_chroma/2
    # add on fixed offset of 10*num_chroma to ensure all values passed to
    # rem are positive
    chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2

    # Gaussian bumps - 2*D to make them narrower
    chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)

    # normalize each column
    if power is not None:
        chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)

    # Maybe apply scaling for fft bins
    if weighting_parameters is not None:
        center, half_width = weighting_parameters
        chroma_filters *= np.tile(
            np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
            (num_chroma, 1),
        )

    if start_at_c_chroma:
        # Rotate the rows by 3 semitone-equivalents so the first chroma bin is 'C' instead of 'A'.
        chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)

    # remove aliasing columns, copy to ensure row-contiguity
    return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
468
+
469
+
470
def mel_filter_bank(
    num_frequency_bins: int,
    num_mel_filters: int,
    min_frequency: float,
    max_frequency: float,
    sampling_rate: int,
    norm: str | None = None,
    mel_scale: str = "htk",
    triangularize_in_mel_space: bool = False,
) -> np.ndarray:
    """
    Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
    various implementations exist, which differ in the number of filters, the shape of the filters, the way the filters
    are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
    features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.

    Different banks of mel filters were introduced in the literature. The following variations are supported:

    - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
      bandwidth of `[0, 4600]` Hz.
    - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
      bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
    - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
      speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
    - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
      12.5 kHz and speech bandwidth of `[0, 6250]` Hz.

    This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
    `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.

    Args:
        num_frequency_bins (`int`):
            Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier
            Transform used to compute the spectrogram).
        num_mel_filters (`int`):
            Number of mel filters to generate.
        min_frequency (`float`):
            Lowest frequency of interest in Hz.
        max_frequency (`float`):
            Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
        sampling_rate (`int`):
            Sample rate of the audio waveform.
        norm (`str`, *optional*):
            If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
        triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
            If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
            should be set to `True` in order to get the same results as `torchaudio` when computing mel filters.

    Returns:
        `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
        projection matrix to go from a spectrogram to a mel spectrogram.

    Raises:
        ValueError: If `norm` is neither `None` nor `"slaney"`, if `num_frequency_bins` is smaller than 2, or if
            `min_frequency` is larger than `max_frequency`.
    """
    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')

    if num_frequency_bins < 2:
        raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2")

    if min_frequency > max_frequency:
        raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}")

    # center points of the triangular mel filters (plus one extra corner on each side)
    mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
    mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
    mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
    filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)

    if triangularize_in_mel_space:
        # frequencies of FFT bins in Hz, but filters triangularized in mel space
        fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2)
        fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
        filter_freqs = mel_freqs
    else:
        # frequencies of FFT bins in Hz
        fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

    mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)

    if norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
        mel_filters *= np.expand_dims(enorm, 0)

    # A filter whose column is all zeros fell between two FFT bins and will never contribute any energy.
    if (mel_filters.max(axis=0) == 0.0).any():
        warnings.warn(
            "At least one mel filter has all zero values. "
            f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
            f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
        )

    return mel_filters
562
+
563
+
564
def optimal_fft_length(window_length: int) -> int:
    """
    Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
    already a power of two, rounds it up to the next power of two.

    The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
    of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
    is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
    it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).

    Args:
        window_length (`int`):
            The length of the window in samples. Must be at least 1.

    Returns:
        `int`: The smallest power of two that is greater than or equal to `window_length`.

    Raises:
        ValueError: If `window_length` is smaller than 1.
    """
    if isinstance(window_length, (int, np.integer)):
        length = int(window_length)
    else:
        # Tolerate float-typed lengths by rounding up, as the logarithm-based formulation did.
        length = int(np.ceil(window_length))

    if length < 1:
        raise ValueError(f"window_length must be a positive integer, got {window_length}")

    # Pure integer arithmetic: exact for any size, unlike a floating-point log2 round trip.
    return 1 << (length - 1).bit_length()
575
+
576
+
577
+ def window_function(
578
+ window_length: int,
579
+ name: str = "hann",
580
+ periodic: bool = True,
581
+ frame_length: int | None = None,
582
+ center: bool = True,
583
+ ) -> np.ndarray:
584
+ """
585
+ Returns an array containing the specified window. This window is intended to be used with `stft`.
586
+
587
+ The following window types are supported:
588
+
589
+ - `"boxcar"`: a rectangular window
590
+ - `"hamming"`: the Hamming window
591
+ - `"hann"`: the Hann window
592
+ - `"povey"`: the Povey window
593
+
594
+ Args:
595
+ window_length (`int`):
596
+ The length of the window in samples.
597
+ name (`str`, *optional*, defaults to `"hann"`):
598
+ The name of the window function.
599
+ periodic (`bool`, *optional*, defaults to `True`):
600
+ Whether the window is periodic or symmetric.
601
+ frame_length (`int`, *optional*):
602
+ The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
603
+ than the frame length, so that it will be zero-padded.
604
+ center (`bool`, *optional*, defaults to `True`):
605
+ Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
606
+
607
+ Returns:
608
+ `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
609
+ """
610
+ length = window_length + 1 if periodic else window_length
611
+
612
+ if name == "boxcar":
613
+ window = np.ones(length)
614
+ elif name in ["hamming", "hamming_window"]:
615
+ window = np.hamming(length)
616
+ elif name in ["hann", "hann_window"]:
617
+ window = np.hanning(length)
618
+ elif name == "povey":
619
+ window = np.power(np.hanning(length), 0.85)
620
+ else:
621
+ raise ValueError(f"Unknown window function '{name}'")
622
+
623
+ if periodic:
624
+ window = window[:-1]
625
+
626
+ if frame_length is None:
627
+ return window
628
+
629
+ if window_length > frame_length:
630
+ raise ValueError(
631
+ f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
632
+ )
633
+
634
+ padded_window = np.zeros(frame_length)
635
+ offset = (frame_length - window_length) // 2 if center else 0
636
+ padded_window[offset : offset + window_length] = window
637
+ return padded_window
638
+
639
+
640
+ # Note: This method processes a single waveform. For batch processing, use spectrogram_batch().
641
+ def spectrogram(
642
+ waveform: np.ndarray,
643
+ window: np.ndarray,
644
+ frame_length: int,
645
+ hop_length: int,
646
+ fft_length: int | None = None,
647
+ power: float | None = 1.0,
648
+ center: bool = True,
649
+ pad_mode: str = "reflect",
650
+ onesided: bool = True,
651
+ dither: float = 0.0,
652
+ preemphasis: float | None = None,
653
+ mel_filters: np.ndarray | None = None,
654
+ mel_floor: float = 1e-10,
655
+ log_mel: str | None = None,
656
+ reference: float = 1.0,
657
+ min_value: float = 1e-10,
658
+ db_range: float | None = None,
659
+ remove_dc_offset: bool = False,
660
+ dtype: np.dtype = np.float32,
661
+ ) -> np.ndarray:
662
+ """
663
+ Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
664
+
665
+ This function can create the following kinds of spectrograms:
666
+
667
+ - amplitude spectrogram (`power = 1.0`)
668
+ - power spectrogram (`power = 2.0`)
669
+ - complex-valued spectrogram (`power = None`)
670
+ - log spectrogram (use `log_mel` argument)
671
+ - mel spectrogram (provide `mel_filters`)
672
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
673
+
674
+ How this works:
675
+
676
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
677
+ - hop_length` samples.
678
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
679
+ 3. The DFT is taken of each windowed frame.
680
+ 4. The results are stacked into a spectrogram.
681
+
682
+ We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
683
+
684
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
685
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
686
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
687
+
688
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
689
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
690
+ typically the next power of two.
691
+
692
+ Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
693
+ `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
694
+ can be constructed.
695
+
696
+ Args:
697
+ waveform (`np.ndarray` of shape `(length,)`):
698
+ The input waveform. This must be a single real-valued, mono waveform.
699
+ window (`np.ndarray` of shape `(frame_length,)`):
700
+ The windowing function to apply, including zero-padding if necessary. The actual window length may be
701
+ shorter than `frame_length`, but we're assuming the array has already been zero-padded.
702
+ frame_length (`int`):
703
+ The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
704
+ allow smaller sizes.
705
+ hop_length (`int`):
706
+ The stride between successive analysis frames in samples.
707
+ fft_length (`int`, *optional*):
708
+ The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
709
+ For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
710
+ power (`float`, *optional*, defaults to 1.0):
711
+ If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
712
+ complex numbers.
713
+ center (`bool`, *optional*, defaults to `True`):
714
+ Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
715
+ `t` will start at time `t * hop_length`.
716
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
717
+ Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
718
+ (pad with edge values), `"reflect"` (pads with mirrored values).
719
+ onesided (`bool`, *optional*, defaults to `True`):
720
+ If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
721
+ frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
722
+ dither (`float`, *optional*, defaults to 0.0):
723
+ Adds dithering. In other words, adds a small Gaussian noise to each frame.
724
+ E.g. use 4.0 to add dithering with a normal distribution centered
725
+ around 0.0 with standard deviation 4.0, 0.0 means no dithering.
726
+ Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
727
+ values for signals with hard-zero sections, when VAD cutoff is present in the signal.
728
+ preemphasis (`float`, *optional*)
729
+ Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
730
+ mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
731
+ The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
732
+ mel_floor (`float`, *optional*, defaults to 1e-10):
733
+ Minimum value of mel frequency banks.
734
+ log_mel (`str`, *optional*):
735
+ How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
736
+ the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
737
+ used when `power` is not `None`.
738
+ reference (`float`, *optional*, defaults to 1.0):
739
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
740
+ the loudest part to 0 dB. Must be greater than zero.
741
+ min_value (`float`, *optional*, defaults to `1e-10`):
742
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
743
+ `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
744
+ amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
745
+ db_range (`float`, *optional*):
746
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
747
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
748
+ remove_dc_offset (`bool`, *optional*):
749
+ Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
750
+ order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
751
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
752
+ Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
753
+ `np.complex64`.
754
+
755
+ Returns:
756
+ `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
757
+ `(num_mel_filters, length)` for a mel spectrogram.
758
+ """
759
+ window_length = len(window)
760
+
761
+ if fft_length is None:
762
+ fft_length = frame_length
763
+
764
+ if frame_length > fft_length:
765
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
766
+
767
+ if window_length != frame_length:
768
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
769
+
770
+ if hop_length <= 0:
771
+ raise ValueError("hop_length must be greater than zero")
772
+
773
+ if waveform.ndim != 1:
774
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
775
+
776
+ if np.iscomplexobj(waveform):
777
+ raise ValueError("Complex-valued input waveforms are not currently supported")
778
+
779
+ if power is None and mel_filters is not None:
780
+ raise ValueError(
781
+ "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
782
+ "Specify `power` to fix this issue."
783
+ )
784
+
785
+ # center pad the waveform
786
+ if center:
787
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
788
+ waveform = np.pad(waveform, padding, mode=pad_mode)
789
+
790
+ # promote to float64, since np.fft uses float64 internally
791
+ waveform = waveform.astype(np.float64)
792
+ window = window.astype(np.float64)
793
+
794
+ # split waveform into frames of frame_length size
795
+ num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
796
+
797
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
798
+ spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
799
+
800
+ # rfft is faster than fft
801
+ fft_func = np.fft.rfft if onesided else np.fft.fft
802
+ buffer = np.zeros(fft_length)
803
+
804
+ timestep = 0
805
+ for frame_idx in range(num_frames):
806
+ buffer[:frame_length] = waveform[timestep : timestep + frame_length]
807
+
808
+ if dither != 0.0:
809
+ buffer[:frame_length] += dither * np.random.randn(frame_length)
810
+
811
+ if remove_dc_offset:
812
+ buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
813
+
814
+ if preemphasis is not None:
815
+ buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
816
+ buffer[0] *= 1 - preemphasis
817
+
818
+ buffer[:frame_length] *= window
819
+
820
+ spectrogram[frame_idx] = fft_func(buffer)
821
+ timestep += hop_length
822
+
823
+ # note: ** is much faster than np.power
824
+ if power is not None:
825
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
826
+
827
+ spectrogram = spectrogram.T
828
+
829
+ if mel_filters is not None:
830
+ spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
831
+
832
+ if power is not None and log_mel is not None:
833
+ if log_mel == "log":
834
+ spectrogram = np.log(spectrogram)
835
+ elif log_mel == "log10":
836
+ spectrogram = np.log10(spectrogram)
837
+ elif log_mel == "dB":
838
+ if power == 1.0:
839
+ spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
840
+ elif power == 2.0:
841
+ spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
842
+ else:
843
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
844
+ else:
845
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
846
+
847
+ spectrogram = np.asarray(spectrogram, dtype)
848
+
849
+ return spectrogram
850
+
851
+
852
+ def spectrogram_batch(
853
+ waveform_list: list[np.ndarray],
854
+ window: np.ndarray,
855
+ frame_length: int,
856
+ hop_length: int,
857
+ fft_length: int | None = None,
858
+ power: float | None = 1.0,
859
+ center: bool = True,
860
+ pad_mode: str = "reflect",
861
+ onesided: bool = True,
862
+ dither: float = 0.0,
863
+ preemphasis: float | None = None,
864
+ mel_filters: np.ndarray | None = None,
865
+ mel_floor: float = 1e-10,
866
+ log_mel: str | None = None,
867
+ reference: float = 1.0,
868
+ min_value: float = 1e-10,
869
+ db_range: float | None = None,
870
+ remove_dc_offset: bool = False,
871
+ dtype: np.dtype = np.float32,
872
+ ) -> list[np.ndarray]:
873
+ """
874
+ Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
875
+ This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.
876
+
877
+ It supports generating various types of spectrograms:
878
+
879
+ - amplitude spectrogram (`power = 1.0`)
880
+ - power spectrogram (`power = 2.0`)
881
+ - complex-valued spectrogram (`power = None`)
882
+ - log spectrogram (use `log_mel` argument)
883
+ - mel spectrogram (provide `mel_filters`)
884
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
885
+
886
+ How this works:
887
+
888
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
889
+ - hop_length` samples.
890
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
891
+ 3. The DFT is taken of each windowed frame.
892
+ 4. The results are stacked into a spectrogram.
893
+
894
+ We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
895
+
896
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
897
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
898
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
899
+
900
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
901
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
902
+ typically the next power of two.
903
+
904
+ Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.
905
+
906
+ Args:
907
+ waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`):
908
+ The list of input waveforms, each a single-channel (mono) signal.
909
+ window (`np.ndarray` of shape `(frame_length,)`):
910
+ The windowing function to apply, including zero-padding if necessary.
911
+ frame_length (`int`):
912
+ The length of each frame for analysis.
913
+ hop_length (`int`):
914
+ The step size between successive frames.
915
+ fft_length (`int`, *optional*):
916
+ The size of the FFT buffer, defining frequency bin resolution.
917
+ power (`float`, *optional*, defaults to 1.0):
918
+ Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
919
+ center (`bool`, *optional*, defaults to `True`):
920
+ Whether to center-pad the waveform frames.
921
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
922
+ The padding strategy when `center` is `True`.
923
+ onesided (`bool`, *optional*, defaults to `True`):
924
+ If True, returns a one-sided spectrogram for real input signals.
925
+ dither (`float`, *optional*, defaults to 0.0):
926
+ Adds dithering. In other words, adds a small Gaussian noise to each frame.
927
+ E.g. use 4.0 to add dithering with a normal distribution centered
928
+ around 0.0 with standard deviation 4.0, 0.0 means no dithering.
929
+ preemphasis (`float`, *optional*):
930
+ Applies a pre-emphasis filter to each frame.
931
+ mel_filters (`np.ndarray`, *optional*):
932
+ Mel filter bank for converting to mel spectrogram.
933
+ mel_floor (`float`, *optional*, defaults to 1e-10):
934
+ Floor value for mel spectrogram to avoid log(0).
935
+ log_mel (`str`, *optional*):
936
+ Specifies log scaling strategy; options are None, "log", "log10", "dB".
937
+ reference (`float`, *optional*, defaults to 1.0):
938
+ Reference value for dB conversion in log_mel.
939
+ min_value (`float`, *optional*, defaults to 1e-10):
940
+ Minimum floor value for log scale conversions.
941
+ db_range (`float`, *optional*):
942
+ Dynamic range for dB scale spectrograms.
943
+ remove_dc_offset (`bool`, *optional*):
944
+ Whether to remove the DC offset from each frame.
945
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
946
+ Data type of the output spectrogram.
947
+
948
+ Returns:
949
+ list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
950
+ """
951
+ window_length = len(window)
952
+
953
+ if fft_length is None:
954
+ fft_length = frame_length
955
+
956
+ if frame_length > fft_length:
957
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
958
+
959
+ if window_length != frame_length:
960
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
961
+
962
+ if hop_length <= 0:
963
+ raise ValueError("hop_length must be greater than zero")
964
+
965
+ # Check the dimensions of the waveform , and if waveform is complex
966
+ for waveform in waveform_list:
967
+ if waveform.ndim != 1:
968
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
969
+ if np.iscomplexobj(waveform):
970
+ raise ValueError("Complex-valued input waveforms are not currently supported")
971
+ # Center pad the waveform
972
+ if center:
973
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
974
+ waveform_list = [
975
+ np.pad(
976
+ waveform,
977
+ padding,
978
+ mode=pad_mode,
979
+ )
980
+ for waveform in waveform_list
981
+ ]
982
+ original_waveform_lengths = [
983
+ len(waveform) for waveform in waveform_list
984
+ ] # these lengths will be used to remove padding later
985
+
986
+ # Batch pad the waveform
987
+ max_length = max(original_waveform_lengths)
988
+ padded_waveform_batch = np.array(
989
+ [
990
+ np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
991
+ for waveform in waveform_list
992
+ ],
993
+ dtype=dtype,
994
+ )
995
+
996
+ # Promote to float64, since np.fft uses float64 internally
997
+ padded_waveform_batch = padded_waveform_batch.astype(np.float64)
998
+ window = window.astype(np.float64)
999
+
1000
+ # Split waveform into frames of frame_length size
1001
+ num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
1002
+ # these lengths will be used to remove padding later
1003
+ true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
1004
+ num_batches = padded_waveform_batch.shape[0]
1005
+
1006
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
1007
+ spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)
1008
+
1009
+ # rfft is faster than fft
1010
+ fft_func = np.fft.rfft if onesided else np.fft.fft
1011
+ buffer = np.zeros((num_batches, fft_length))
1012
+
1013
+ for frame_idx in range(num_frames):
1014
+ timestep = frame_idx * hop_length
1015
+ buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
1016
+
1017
+ if dither != 0.0:
1018
+ buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
1019
+
1020
+ if remove_dc_offset:
1021
+ buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
1022
+
1023
+ if preemphasis is not None:
1024
+ buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
1025
+ buffer[:, 0] *= 1 - preemphasis
1026
+
1027
+ buffer[:, :frame_length] *= window
1028
+
1029
+ spectrogram[:, frame_idx] = fft_func(buffer)
1030
+
1031
+ # Note: ** is much faster than np.power
1032
+ if power is not None:
1033
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
1034
+
1035
+ # Apply mel filters if provided
1036
+ if mel_filters is not None:
1037
+ result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
1038
+ spectrogram = np.maximum(mel_floor, result)
1039
+
1040
+ # Convert to log scale if specified
1041
+ if power is not None and log_mel is not None:
1042
+ if log_mel == "log":
1043
+ spectrogram = np.log(spectrogram)
1044
+ elif log_mel == "log10":
1045
+ spectrogram = np.log10(spectrogram)
1046
+ elif log_mel == "dB":
1047
+ if power == 1.0:
1048
+ spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
1049
+ elif power == 2.0:
1050
+ spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
1051
+ else:
1052
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
1053
+ else:
1054
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
1055
+
1056
+ spectrogram = np.asarray(spectrogram, dtype)
1057
+
1058
+ spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
1059
+
1060
+ return spectrogram_list
1061
+
1062
+
1063
+ def power_to_db(
1064
+ spectrogram: np.ndarray,
1065
+ reference: float = 1.0,
1066
+ min_value: float = 1e-10,
1067
+ db_range: float | None = None,
1068
+ ) -> np.ndarray:
1069
+ """
1070
+ Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
1071
+ logarithm properties for numerical stability.
1072
+
1073
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
1074
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
1075
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
1076
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
1077
+
1078
+ Based on the implementation of `librosa.power_to_db`.
1079
+
1080
+ Args:
1081
+ spectrogram (`np.ndarray`):
1082
+ The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
1083
+ reference (`float`, *optional*, defaults to 1.0):
1084
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
1085
+ the loudest part to 0 dB. Must be greater than zero.
1086
+ min_value (`float`, *optional*, defaults to `1e-10`):
1087
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
1088
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
1089
+ db_range (`float`, *optional*):
1090
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
1091
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
1092
+
1093
+ Returns:
1094
+ `np.ndarray`: the spectrogram in decibels
1095
+ """
1096
+ if reference <= 0.0:
1097
+ raise ValueError("reference must be greater than zero")
1098
+ if min_value <= 0.0:
1099
+ raise ValueError("min_value must be greater than zero")
1100
+
1101
+ reference = max(min_value, reference)
1102
+
1103
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
1104
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
1105
+
1106
+ if db_range is not None:
1107
+ if db_range <= 0.0:
1108
+ raise ValueError("db_range must be greater than zero")
1109
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
1110
+
1111
+ return spectrogram
1112
+
1113
+
1114
+ def power_to_db_batch(
1115
+ spectrogram: np.ndarray,
1116
+ reference: float = 1.0,
1117
+ min_value: float = 1e-10,
1118
+ db_range: float | None = None,
1119
+ ) -> np.ndarray:
1120
+ """
1121
+ Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
1122
+ using basic logarithm properties for numerical stability.
1123
+
1124
+ This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram.
1125
+
1126
+ Args:
1127
+ spectrogram (`np.ndarray`):
1128
+ The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
1129
+ Note that a power spectrogram has the amplitudes squared!
1130
+ reference (`float`, *optional*, defaults to 1.0):
1131
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
1132
+ the loudest part to 0 dB. Must be greater than zero.
1133
+ min_value (`float`, *optional*, defaults to `1e-10`):
1134
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
1135
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
1136
+ db_range (`float`, *optional*):
1137
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
1138
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
1139
+
1140
+ Returns:
1141
+ `np.ndarray`: the batch of spectrograms in decibels
1142
+ """
1143
+ if reference <= 0.0:
1144
+ raise ValueError("reference must be greater than zero")
1145
+ if min_value <= 0.0:
1146
+ raise ValueError("min_value must be greater than zero")
1147
+
1148
+ reference = max(min_value, reference)
1149
+
1150
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
1151
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
1152
+
1153
+ if db_range is not None:
1154
+ if db_range <= 0.0:
1155
+ raise ValueError("db_range must be greater than zero")
1156
+ # Apply db_range clipping per batch item
1157
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
1158
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
1159
+
1160
+ return spectrogram
1161
+
1162
+
1163
+ def amplitude_to_db(
1164
+ spectrogram: np.ndarray,
1165
+ reference: float = 1.0,
1166
+ min_value: float = 1e-5,
1167
+ db_range: float | None = None,
1168
+ ) -> np.ndarray:
1169
+ """
1170
+ Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
1171
+ basic logarithm properties for numerical stability.
1172
+
1173
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
1174
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
1175
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
1176
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
1177
+
1178
+ Args:
1179
+ spectrogram (`np.ndarray`):
1180
+ The input amplitude (mel) spectrogram.
1181
+ reference (`float`, *optional*, defaults to 1.0):
1182
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
1183
+ the loudest part to 0 dB. Must be greater than zero.
1184
+ min_value (`float`, *optional*, defaults to `1e-5`):
1185
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
1186
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
1187
+ db_range (`float`, *optional*):
1188
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
1189
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
1190
+
1191
+ Returns:
1192
+ `np.ndarray`: the spectrogram in decibels
1193
+ """
1194
+ if reference <= 0.0:
1195
+ raise ValueError("reference must be greater than zero")
1196
+ if min_value <= 0.0:
1197
+ raise ValueError("min_value must be greater than zero")
1198
+
1199
+ reference = max(min_value, reference)
1200
+
1201
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
1202
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
1203
+
1204
+ if db_range is not None:
1205
+ if db_range <= 0.0:
1206
+ raise ValueError("db_range must be greater than zero")
1207
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
1208
+
1209
+ return spectrogram
1210
+
1211
+
1212
+ def amplitude_to_db_batch(
1213
+ spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: float | None = None
1214
+ ) -> np.ndarray:
1215
+ """
1216
+ Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
1217
+ using basic logarithm properties for numerical stability.
1218
+
1219
+ The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
1220
+
1221
+ Args:
1222
+ spectrogram (`np.ndarray`):
1223
+ The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
1224
+ reference (`float`, *optional*, defaults to 1.0):
1225
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
1226
+ the loudest part to 0 dB. Must be greater than zero.
1227
+ min_value (`float`, *optional*, defaults to `1e-5`):
1228
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
1229
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
1230
+ db_range (`float`, *optional*):
1231
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
1232
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
1233
+
1234
+ Returns:
1235
+ `np.ndarray`: the batch of spectrograms in decibels
1236
+ """
1237
+ if reference <= 0.0:
1238
+ raise ValueError("reference must be greater than zero")
1239
+ if min_value <= 0.0:
1240
+ raise ValueError("min_value must be greater than zero")
1241
+
1242
+ reference = max(min_value, reference)
1243
+
1244
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
1245
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
1246
+
1247
+ if db_range is not None:
1248
+ if db_range <= 0.0:
1249
+ raise ValueError("db_range must be greater than zero")
1250
+ # Apply db_range clipping per batch item
1251
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
1252
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
1253
+
1254
+ return spectrogram
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/add_new_model_like.py ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import difflib
15
+ import os
16
+ import re
17
+ import subprocess
18
+ import textwrap
19
+ from collections.abc import Callable
20
+ from datetime import date
21
+ from pathlib import Path
22
+ from typing import Annotated, Any, cast
23
+
24
+ import typer
25
+
26
+ from ..utils import is_libcst_available
27
+
28
+
29
+ # We protect this import to avoid requiring it for all `transformers` CLI commands - however it is actually
30
+ # strictly required for this one (we need it both for modular and for the following Visitor)
31
if is_libcst_available():
    import libcst as cst
    from libcst import CSTVisitor
    from libcst import matchers as m

    class ClassFinder(CSTVisitor):
        """
        Visitor collecting every class name of a python module, plus the entries of its `__all__` assignment.
        """

        def __init__(self):
            # Names of all visited class definitions, in visiting order
            self.classes: list = []
            # Raw values of the elements assigned to `__all__`
            self.public_classes: list = []
            # True between entering and leaving a class definition
            self.is_in_class = False

        def visit_ClassDef(self, node: cst.ClassDef) -> None:
            """Store the class name. Classes are assumed to be defined at top-level only (never inside a function or similar)."""
            self.is_in_class = True
            self.classes.append(node.name.value)

        def leave_ClassDef(self, node: cst.ClassDef):
            self.is_in_class = False

        def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):
            """Store the elements of the `__all__` assignment as the public classes."""
            # Assignments located inside a class body are never `__all__`
            if self.is_in_class:
                return
            single_name_assignment = m.SimpleStatementLine(
                body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
            )
            if not m.matches(node, single_name_assignment):
                return

            assignment = cast(cst.Assign, node.body[0])
            target_name = cast(cst.Name, assignment.targets[0].target).value
            if target_name != "__all__":
                return

            entries = cast(cst.Tuple, assignment.value).elements
            self.public_classes = [cast(cst.SimpleString, entry.value).value for entry in entries]
65
+
66
+
67
# Calendar year at import time; interpolated into the copyright header below.
CURRENT_YEAR = date.today().year
# Directory located three levels above the folder that contains this file.
REPO_PATH = Path(__file__).parents[3]

# License header template (leading newline stripped) stamped with the current year.
COPYRIGHT = f"""
# coding=utf-8
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""".lstrip()
86
+
87
+ ### Entrypoint
88
+
89
+
90
def add_new_model_like(
    repo_path: Annotated[
        str | None, typer.Argument(help="When not using an editable install, the path to the Transformers repo.")
    ] = None,
):
    """
    Add a new model to the library, based on an existing one.
    """
    # Gather the description of the model to copy and of the model to create.
    old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add = get_user_input()

    # Without an explicit path, work on the repository this module belongs to.
    target_repo = REPO_PATH if repo_path is None else Path(repo_path)

    _add_new_model_like_internal(
        repo_path=target_repo,
        old_model_infos=old_model_infos,
        new_lowercase_name=new_lowercase_name,
        new_model_paper_name=new_model_paper_name,
        filenames_to_add=filenames_to_add,
    )
112
+
113
+
114
+ ### Core logic
115
+
116
+
117
class ModelInfos:
    """
    Retrieve the basic information about the classes of an existing model.

    The model is looked up in the auto mappings by its lowercase name, first with underscores and then with
    hyphens as word separators.

    Args:
        lowercase_name (`str`):
            The name of the existing model. Casing, spaces and hyphens are normalized before the lookup.

    Raises:
        ValueError: if the name matches no entry of `CONFIG_MAPPING_NAMES` under either spelling.
    """

    def __init__(self, lowercase_name: str):
        from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES
        from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
        from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
        from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
        from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
        from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES

        # Just to make sure it's indeed lowercase
        self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_")
        if self.lowercase_name not in CONFIG_MAPPING_NAMES:
            # Retry with the hyphenated spelling. `str.replace` returns a new string, so the result must be
            # assigned back for the second lookup to see it.
            self.lowercase_name = self.lowercase_name.replace("_", "-")
        if self.lowercase_name not in CONFIG_MAPPING_NAMES:
            raise ValueError(f"{lowercase_name} is not a valid model name")

        self.config_class = CONFIG_MAPPING_NAMES[self.lowercase_name]
        self.camelcase_name = self.config_class.replace("Config", "")

        # Get tokenizer class
        if self.lowercase_name in TOKENIZER_MAPPING_NAMES:
            self.tokenizer_class = None
            self.fast_tokenizer_class = TOKENIZER_MAPPING_NAMES[self.lowercase_name]
            # The generic fast tokenizer is not a model-specific class, so it is reported as missing
            self.fast_tokenizer_class = (
                None if self.fast_tokenizer_class == "PreTrainedTokenizerFast" else self.fast_tokenizer_class
            )
        else:
            self.tokenizer_class, self.fast_tokenizer_class = None, None

        self.image_processor_classes = IMAGE_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
        self.video_processor_class = VIDEO_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
        self.feature_extractor_class = FEATURE_EXTRACTOR_MAPPING_NAMES.get(self.lowercase_name, None)
        self.processor_class = PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
154
+
155
+
156
+ def add_content_to_file(file_name: str | os.PathLike, new_content: str, add_after: str):
157
+ """
158
+ A utility to add some content inside a given file.
159
+
160
+ Args:
161
+ file_name (`str` or `os.PathLike`):
162
+ The name of the file in which we want to insert some content.
163
+ new_content (`str`):
164
+ The content to add.
165
+ add_after (`str`):
166
+ The new content is added just after the first instance matching it.
167
+ """
168
+ with open(file_name, "r", encoding="utf-8") as f:
169
+ old_content = f.read()
170
+
171
+ before, after = old_content.split(add_after, 1)
172
+ new_content = before + add_after + new_content + after
173
+
174
+ with open(file_name, "w", encoding="utf-8") as f:
175
+ f.write(new_content)
176
+
177
+
178
def add_model_to_auto_mappings(
    repo_path: Path,
    old_model_infos: ModelInfos,
    new_lowercase_name: str,
    new_model_paper_name: str,
    filenames_to_add: list[tuple[str, bool]],
) -> None:
    """
    Add a model to all the relevant mappings in the auto module, by duplicating the entries of the old model under
    the new names.

    Args:
        repo_path (`Path`):
            The root of the repository; the auto files are edited in place under `src/transformers/models/auto`.
        old_model_infos (`ModelInfos`):
            The structure containing the class information of the old model.
        new_lowercase_name (`str`):
            The new lowercase model name.
        new_model_paper_name (`str`):
            The fully cased name (as in the official paper name) of the new model. Not used in this function.
        filenames_to_add (`list[tuple[str, bool]]`):
            A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
            should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
            The first entry is skipped here.
    """
    new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
    old_lowercase_name = old_model_infos.lowercase_name
    old_cased_name = old_model_infos.camelcase_name
    # Map every model filename (except the first one) to the corresponding `*_auto*` filename
    filenames_to_add = [
        (filename.replace(old_lowercase_name, "auto"), to_add) for filename, to_add in filenames_to_add[1:]
    ]
    # fast tokenizer has the same auto mappings as normal ones
    corrected_filenames_to_add = []
    has_image_processor = has_video_processor = False
    for file, to_add in filenames_to_add:
        # Image/video processors are only recorded as flags (whatever their `to_add` value) and handled below
        if "image_processing" in file:
            has_image_processor = True
        elif "video_processing" in file:
            has_video_processor = True
        # NOTE(review): `|` binds loosest, so this pattern reads as "contains `tokenization`" OR "contains
        # `image_processing_auto_fast.py`" and therefore also matches a non-fast `tokenization_auto.py`. The
        # image-processing alternative cannot be reached because of the first branch above. Confirm this is
        # the intended grouping.
        elif re.search(r"(?:tokenization)|(?:image_processing)_auto_fast.py", file):
            # Merge the flag into the previously kept file. Raises IndexError if nothing was kept yet.
            previous_file, previous_to_add = corrected_filenames_to_add[-1]
            corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add)
        else:
            corrected_filenames_to_add.append((file, to_add))

    # Add the config and image/video processor mappings directly as the handling is a bit different
    # NOTE(review): `add_content_to_file` needs each `add_after` anchor to appear verbatim in the target file,
    # whitespace included - confirm the anchor literals below against `auto_mappings.py`.
    add_content_to_file(
        repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py",
        new_content=f'("{new_lowercase_name}", "{new_cased_name}Config"),\n ',
        add_after="CONFIG_MAPPING_NAMES = OrderedDict(\n [\n ",
    )
    # Read after the config insertion above, so `autofile` already contains the new config entry
    autofile = (repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py").read_text()
    if has_image_processor:
        # Lines of the form `("old_name", {...}),` - only the first one found is duplicated
        matching_lines = re.findall(rf'^\s+\("{old_lowercase_name}",\s+{{[^}}]+}}\),?$', autofile, re.MULTILINE)
        if matching_lines:
            match = matching_lines[0]
            add_content_to_file(
                repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py",
                new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
                    old_cased_name, new_cased_name
                )
                + "\n",
                add_after="IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(\n [\n",
            )
    if has_video_processor:
        # Extract the VIDEO_PROCESSOR_MAPPING_NAMES block first
        block_match = re.search(
            r"VIDEO_PROCESSOR_MAPPING_NAMES\s*=\s*OrderedDict\(\s*\[(.*?)\]\s*\)", autofile, re.DOTALL
        )
        # NOTE(review): `block_match` is None when the block is not found, which makes the next line raise
        # AttributeError
        block = block_match.group(1)  # type: ignore
        # Lines of the form `("old_name", "SomeClass"),` inside that block - only the first one is duplicated
        matching_lines = re.findall(rf'^\s+\("{old_lowercase_name}",\s+"[^"]+"\),?$', block, re.MULTILINE)
        if matching_lines:
            match = matching_lines[0]
            add_content_to_file(
                repo_path / "src" / "transformers" / "models" / "auto" / "auto_mappings.py",
                new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
                    old_cased_name, new_cased_name
                )
                + "\n",
                add_after="VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(\n [\n",
            )

    for filename, to_add in corrected_filenames_to_add:
        if to_add:
            # The auto mapping
            filename = filename.replace("_fast.py", ".py")
            file = (repo_path / "src" / "transformers" / "models" / "auto" / filename).read_text()
            # The regex has to be a bit complex like this as the tokenizer mapping has new lines everywhere
            matching_lines = re.findall(
                rf'( {{8,12}}\(\s*"{old_lowercase_name}",.*?\),\n)(?: {{4,12}}\(|\])', file, re.DOTALL
            )
            # Duplicate every entry of the old model right after itself, renamed for the new model
            for match in matching_lines:
                add_content_to_file(
                    repo_path / "src" / "transformers" / "models" / "auto" / filename,
                    new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
                        old_cased_name, new_cased_name
                    ),
                    add_after=match,
                )
273
+
274
+
275
def create_doc_file(new_paper_name: str, public_classes: list[str]):
    """
    Build the content of the documentation page to fill for the new model.

    The page is made of the license header (turned into a Markdown comment), a template with placeholders to
    complete, and one `[[autodoc]]` section per public class.

    Args:
        new_paper_name (`str`):
            The fully cased name (as in the official paper name) of the new model.
        public_classes (`list[str]`):
            A list of all the public classes that the model will have in the library.

    Returns:
        `str`: the full content of the doc file.
    """
    # Strip the `#` prefixes of the header and open a Markdown comment in place of the coding line; the note
    # below closes that comment.
    markdown_license = re.sub(r"# ?", "", COPYRIGHT).replace("coding=utf-8\n", "<!--")
    markdown_license += (
        "\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that "
        "may not be rendered properly in your Markdown viewer.\n\n-->\n\n"
    )

    page_template = textwrap.dedent(
        f"""
        # {new_paper_name}

        ## Overview

        The {new_paper_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
        <INSERT SHORT SUMMARY HERE>

        The abstract from the paper is the following:

        <INSERT PAPER ABSTRACT HERE>

        Tips:

        <INSERT TIPS ABOUT MODEL HERE>

        This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
        The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).

        ## Usage examples

        <INSERT SOME NICE EXAMPLES HERE>

        """
    )

    # One section per public class; classes with "Model" in their name also document `forward`
    class_sections = [
        f"## {name}\n\n[[autodoc]] {name}" + ("\n - forward" if "Model" in name else "")
        for name in public_classes
    ]

    return markdown_license + page_template + "\n\n".join(class_sections)
329
+
330
+
331
def insert_model_in_doc_toc(
    repo_path: Path, old_lowercase_name: str, new_lowercase_name: str, new_model_paper_name: str
):
    """
    Insert the new model in the doc `_toctree.yml`, in the same section as the old model.

    The new entry is written right after the entry of the old model and reuses its indentation, so both entries
    are always aligned.

    Args:
        repo_path (`Path`):
            The root of the repository; the file edited is `docs/source/en/_toctree.yml` under it.
        old_lowercase_name (`str`):
            The old lowercase model name.
        new_lowercase_name (`str`):
            The new lowercase model name.
        new_model_paper_name (`str`):
            The fully cased name (as in the official paper name) of the new model.

    Raises:
        ValueError: if no entry for the old model is found in the table of contents.
    """
    toc_file = repo_path / "docs" / "source" / "en" / "_toctree.yml"
    # Same encoding as the one `add_content_to_file` uses to write the file back
    content = toc_file.read_text(encoding="utf-8")

    # The name is escaped so it is always matched literally
    toc_match = re.search(
        rf"^(?P<entry_indent> *)- local: model_doc/{re.escape(old_lowercase_name)}\n(?P<title_indent> *)title: .*\n",
        content,
        re.MULTILINE,
    )
    if toc_match is None:
        raise ValueError(f"Could not find TOC entry for {old_lowercase_name}")

    new_toc = (
        f"{toc_match['entry_indent']}- local: model_doc/{new_lowercase_name}\n"
        f"{toc_match['title_indent']}title: {new_model_paper_name}\n"
    )
    add_content_to_file(toc_file, new_content=new_toc, add_after=toc_match.group(0))
357
+
358
+
359
def create_init_file(old_lowercase_name: str, new_lowercase_name: str, filenames_to_add: list[tuple[str, bool]]):
    """
    Create the `__init__.py` file to add in the new model folder.

    Args:
        old_lowercase_name (`str`):
            The old lowercase model name.
        new_lowercase_name (`str`):
            The new lowercase model name.
        filenames_to_add (`list[tuple[str, bool]]`):
            A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
            should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]

    Returns:
        `str`: the content of the `__init__.py` file.
    """
    # Module names (renamed for the new model, without extension) of the files that are actually added
    modules = [
        filename.replace(old_lowercase_name, new_lowercase_name).replace(".py", "")
        for filename, to_add in filenames_to_add
        if to_add
    ]
    # `pass` keeps the generated `if TYPE_CHECKING:` block valid when no file is selected
    import_lines = [f"from .{module} import *" for module in modules] or ["pass"]
    # Indent explicitly, so the layout does not depend on whitespace hidden in a separator literal
    type_checking_body = textwrap.indent("\n".join(import_lines), "    ")

    header = textwrap.dedent(
        """
        from typing import TYPE_CHECKING

        from ...utils import _LazyModule
        from ...utils.import_utils import define_import_structure


        if TYPE_CHECKING:
        """
    )
    footer = textwrap.dedent(
        """\
        else:
            import sys

            _file = globals()["__file__"]
            sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
        """
    )
    return COPYRIGHT + header + type_checking_body + "\n" + footer
395
+
396
+
397
def find_all_classes_from_file(module_name: str | os.PathLike) -> tuple[list, list]:
    """
    Find the name of all classes defined in `module_name`, including public ones (defined in `__all__`).

    Args:
        module_name (`str` or `os.PathLike`):
            The full path to the python module from which to extract classes.

    Returns:
        `tuple[list, list]`: the names of all the classes defined in the module, and the entries found in its
        `__all__` assignment (empty when there is none).
    """
    with open(module_name, "r", encoding="utf-8") as file:
        source_code = file.read()
    module = cst.parse_module(source_code)
    visitor = ClassFinder()
    module.visit(visitor)
    return visitor.classes, visitor.public_classes
411
+
412
+
413
def find_modular_structure(
    module_name: Path, old_model_infos: ModelInfos, new_cased_name: str
) -> tuple[str, str, list]:
    """
    Build the pieces needed to copy the file `module_name` with modular: the import statement, the class
    definitions and the public class names.

    Args:
        module_name (`Path`):
            The full path to the python module to copy with modular.
        old_model_infos (`ModelInfos`):
            The structure containing the class information of the old model.
        new_cased_name (`str`):
            The new cased model name.

    Returns:
        `tuple[str, str, list]`: the import of every class of the old module, the source of one renamed empty
        subclass per class, and the renamed public classes.
    """
    old_cased_name = old_model_infos.camelcase_name

    def renamed(class_name):
        # Swap the old model prefix for the new one
        return class_name.replace(old_cased_name, new_cased_name)

    all_classes, public_classes = find_all_classes_from_file(module_name)

    # `<model folder>.<module>`, imported relative to the parent package
    import_location = ".".join(module_name.parts[-2:]).replace(".py", "")
    imports = "from .." + import_location + " import " + ", ".join(all_classes)

    subclass_sources = [f"class {renamed(class_)}({class_}):\n pass" for class_ in all_classes]
    renamed_public_classes = [renamed(class_) for class_ in public_classes]
    return imports, "\n\n".join(subclass_sources), renamed_public_classes
436
+
437
+
438
def create_modular_file(
    repo_path: Path,
    old_model_infos: ModelInfos,
    new_lowercase_name: str,
    filenames_to_add: list[tuple[str, bool]],
) -> tuple[str, list[str]]:
    """
    Create a new modular file which will copy the old model, based on the new name and the different filenames
    (modules) to add.

    Args:
        repo_path (`Path`):
            The root of the repository; the old model is read from `src/transformers/models/<old name>` under it.
        old_model_infos (`ModelInfos`):
            The structure containing the class information of the old model.
        new_lowercase_name (`str`):
            The new lowercase model name.
        filenames_to_add (`list[tuple[str, bool]]`):
            A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
            should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]

    Returns:
        `tuple[str, list[str]]`: the content of the modular file, and the names of its public classes (without
        quotes).
    """
    new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
    old_lowercase_name = old_model_infos.lowercase_name
    old_folder_root = repo_path / "src" / "transformers" / "models" / old_lowercase_name

    # Construct the modular file from the original (old) model, by subclassing each class
    import_statements = []
    class_bodies = []
    all_public_classes = []
    for filename, to_add in filenames_to_add:
        if not to_add:
            continue
        imports, body, public_classes = find_modular_structure(
            old_folder_root / filename, old_model_infos, new_cased_name
        )
        import_statements.append(f"\n{imports}")
        class_bodies.append(f"\n\n{body}")
        all_public_classes.extend(public_classes)

    # Create the __all__ assignment, with one explicitly indented entry per line
    all_entries = "\n".join(f"    {public_class}," for public_class in all_public_classes)
    all_statement = f"\n\n__all__ = [\n{all_entries}\n]\n"

    # Create the whole modular file
    modular_file = COPYRIGHT + "".join(import_statements) + "".join(class_bodies) + all_statement
    # Remove outer explicit quotes "" around the public class names before returning them
    all_public_classes = [public_class.replace('"', "") for public_class in all_public_classes]
    return modular_file, all_public_classes
489
+
490
+
491
def create_test_files(
    repo_path: Path, old_model_infos: ModelInfos, new_lowercase_name: str, filenames_to_add: list[tuple[str, bool]]
) -> dict[str, str]:
    """
    Create the test files for the new model. It basically copies over the old test files and adjust the class names.

    Args:
        repo_path (`Path`):
            The path to the root of the Transformers repository.
        old_model_infos (`ModelInfos`):
            The structure containing the class information of the old model.
        new_lowercase_name (`str`):
            The new lowercase model name.
        filenames_to_add (`list[tuple[str, bool]]`):
            A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
            should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]

    Returns:
        `dict[str, str]`: A mapping from each new test filename to its content. Old test files that do not exist on
        disk are silently skipped.
    """
    new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
    old_lowercase_name = old_model_infos.lowercase_name
    old_cased_name = old_model_infos.camelcase_name
    # The first entry of `filenames_to_add` is skipped when deriving the test filenames
    test_filenames = [
        ("test_" + filename.replace(old_lowercase_name, new_lowercase_name), to_add)
        for filename, to_add in filenames_to_add[1:]
    ]

    # fast tokenizer/image processor have the same test files as normal ones.
    # The alternation must live inside a single group: written as `(?:a)|(?:b)_name_fast.py`, the `|` splits the
    # whole pattern and the plain `test_tokenization_*.py` would match as well.
    fast_test_pattern = re.compile(
        rf"test_(?:tokenization|image_processing)_{re.escape(new_lowercase_name)}_fast\.py"
    )
    corrected_filenames_to_add = []
    for file, to_add in test_filenames:
        if corrected_filenames_to_add and fast_test_pattern.fullmatch(file):
            # Fold the "fast" flag into the entry of the corresponding normal test file
            previous_file, previous_to_add = corrected_filenames_to_add[-1]
            corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add)
        else:
            corrected_filenames_to_add.append((file, to_add))

    test_files = {}
    for new_file, to_add in corrected_filenames_to_add:
        if not to_add:
            continue
        original_test_file = new_file.replace(new_lowercase_name, old_lowercase_name)
        original_test_path = repo_path / "tests" / "models" / old_lowercase_name / original_test_file
        # Sometimes, tests may not exist
        if not original_test_path.is_file():
            continue
        with open(original_test_path, "r") as f:
            test_code = f.read()
        # Remove old copyright and add new one. The bound check avoids an IndexError on a file that only
        # contains comment lines.
        test_lines = test_code.split("\n")
        idx = 0
        while idx < len(test_lines) and test_lines[idx].startswith("#"):
            idx += 1
        test_code = COPYRIGHT + "\n".join(test_lines[idx:])
        test_files[new_file] = test_code.replace(old_cased_name, new_cased_name)

    return test_files
541
+
542
+
543
def _add_new_model_like_internal(
    repo_path: Path,
    old_model_infos: ModelInfos,
    new_lowercase_name: str,
    new_model_paper_name: str,
    filenames_to_add: list[tuple[str, bool]],
):
    """
    Creates a new model module like a given model of the Transformers library.

    Args:
        repo_path (`Path`):
            The path to the root of the Transformers repository.
        old_model_infos (`ModelInfos`):
            The structure containing the class information of the old model.
        new_lowercase_name (`str`):
            The new lowercase model name.
        new_model_paper_name (`str`):
            The fully cased name (as in the official paper name) of the new model.
        filenames_to_add (`list[tuple[str, bool]]`):
            A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we
            should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]

    Raises:
        ValueError: If `libcst` is not installed.
    """
    # `libcst` is a hard dependency of this command even though its import is guarded, so fail early without it
    if not is_libcst_available():
        raise ValueError("You need to install `libcst` to run this command -> `pip install libcst`")

    def _write(path, content):
        with open(path, "w") as f:
            f.write(content)

    old_lowercase_name = old_model_infos.lowercase_name
    models_root = repo_path / "src" / "transformers" / "models"
    models_init_file = models_root / "__init__.py"

    # 1. Folder of the new model
    new_module_folder = models_root / new_lowercase_name
    os.makedirs(new_module_folder, exist_ok=True)

    # 2. Modular file
    modular_file, public_classes = create_modular_file(
        repo_path, old_model_infos, new_lowercase_name, filenames_to_add
    )
    _write(new_module_folder / f"modular_{new_lowercase_name}.py", modular_file)

    # 3. `__init__.py` of the new model
    _write(new_module_folder / "__init__.py", create_init_file(old_lowercase_name, new_lowercase_name, filenames_to_add))

    # 4. Register the new model in the models init
    add_content_to_file(
        models_init_file,
        new_content=f"    from .{new_lowercase_name} import *\n",
        add_after="if TYPE_CHECKING:\n",
    )

    # 5. Auto mappings
    add_model_to_auto_mappings(repo_path, old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add)

    # 6. Test files, starting with an empty `__init__.py`
    tests_folder = repo_path / "tests" / "models" / new_lowercase_name
    os.makedirs(tests_folder, exist_ok=True)
    _write(tests_folder / "__init__.py", "")
    for filename, content in create_test_files(
        repo_path, old_model_infos, new_lowercase_name, filenames_to_add
    ).items():
        _write(tests_folder / filename, content)

    # 7. Doc file and its table-of-contents entry
    _write(
        repo_path / "docs" / "source" / "en" / "model_doc" / f"{new_lowercase_name}.md",
        create_doc_file(new_model_paper_name, public_classes),
    )
    insert_model_in_doc_toc(repo_path, old_lowercase_name, new_lowercase_name, new_model_paper_name)

    # 8. Linters and repository fix-up scripts, then the modular conversion last. All of them run from the
    # repository root with their standard output discarded; their exit codes are not checked.
    post_processing_commands = (
        ["ruff", "check", new_module_folder, tests_folder, models_init_file, "--fix"],
        ["ruff", "format", new_module_folder, tests_folder, models_init_file],
        ["python", "utils/check_doc_toc.py", "--fix_and_overwrite"],
        ["python", "utils/sort_auto_mappings.py"],
        ["python", "utils/modular_model_converter.py", new_lowercase_name],
    )
    for command in post_processing_commands:
        subprocess.run(command, cwd=repo_path, stdout=subprocess.DEVNULL)
636
+
637
+
638
+ def get_user_field(
639
+ question: str,
640
+ default_value: str | None = None,
641
+ convert_to: Callable | None = None,
642
+ fallback_message: str | None = None,
643
+ ) -> Any:
644
+ """
645
+ A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
646
+ answer.
647
+
648
+ Args:
649
+ question (`str`):
650
+ The question to ask the user.
651
+ default_value (`str`, *optional*):
652
+ A potential default value that will be used when the answer is empty.
653
+ convert_to (`Callable`, *optional*):
654
+ If set, the answer will be passed to this function. If this function raises an error on the provided
655
+ answer, the question will be asked again.
656
+ fallback_message (`str`, *optional*):
657
+ A message that will be displayed each time the question is asked again to the user.
658
+
659
+ Returns:
660
+ `Any`: The answer provided by the user (or the default), passed through the potential conversion function.
661
+ """
662
+ if not question.endswith(" "):
663
+ question = question + " "
664
+ if default_value is not None:
665
+ question = f"{question} [{default_value}] "
666
+
667
+ valid_answer = False
668
+ while not valid_answer:
669
+ answer = input(question)
670
+ if default_value is not None and len(answer) == 0:
671
+ answer = default_value
672
+ if convert_to is not None:
673
+ try:
674
+ answer = convert_to(answer)
675
+ valid_answer = True
676
+ except Exception:
677
+ valid_answer = False
678
+ else:
679
+ valid_answer = True
680
+
681
+ if not valid_answer:
682
+ print(fallback_message)
683
+
684
+ return answer
685
+
686
+
687
def convert_to_bool(x: str) -> bool:
    """
    Interpret a yes/no style answer as a boolean (case-insensitive).

    Raises:
        ValueError: If `x` is not one of the recognised spellings.
    """
    lowered = x.lower()
    spellings = (
        (("1", "y", "yes", "true"), True),
        (("0", "n", "no", "false"), False),
    )
    for accepted, outcome in spellings:
        if lowered in accepted:
            return outcome
    raise ValueError(f"{x} is not a value that can be converted to a bool.")
696
+
697
+
698
def get_user_input():
    """
    Ask the user for the necessary inputs to add the new model.

    Returns:
        A 4-tuple `(old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add)`, where
        `filenames_to_add` is a tuple of `(old_filename, should_add)` pairs.
    """
    from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES

    model_types = list(CONFIG_MAPPING_NAMES.keys())

    # Get old model type, looping until the answer is a known model type
    while True:
        old_model_type = input(
            "What model would you like to duplicate? Please provide it as lowercase (e.g. `llama`): "
        )
        if old_model_type in model_types:
            break
        print(f"{old_model_type} is not a valid model type.")
        near_choices = difflib.get_close_matches(old_model_type, model_types)
        if near_choices:
            # Always join, so that a single suggestion is not printed as a list repr
            print(f"Did you mean {' or '.join(near_choices)}?")

    old_model_infos = ModelInfos(old_model_type)

    # Ask for the new model name
    new_lowercase_name = get_user_field(
        "What is the new model name? Please provide it as snake lowercase, e.g. `new_model`?"
    )
    new_model_paper_name = get_user_field(
        "What is the fully cased name you would like to appear in the doc (e.g. `NeW ModEl`)? ",
        default_value="".join(x.title() for x in new_lowercase_name.split("_")),
    )

    def _ask_for_new(component: str, old_class) -> bool:
        """Ask whether to create a new `component`; no question (and `False`) when the old model has none."""
        if old_class is None:
            return False
        return get_user_field(
            f"Do you want to create a new {component}? If `no`, it will use the same as {old_model_type} (y/n)?",
            convert_to=convert_to_bool,
            fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
        )

    # Ask if we want to add individual processor classes as well
    add_tokenizer = _ask_for_new("tokenizer", old_model_infos.tokenizer_class)
    add_fast_tokenizer = _ask_for_new("fast tokenizer", old_model_infos.fast_tokenizer_class)
    add_image_processor = _ask_for_new("image processor", old_model_infos.image_processor_classes)
    add_video_processor = _ask_for_new("video processor", old_model_infos.video_processor_class)
    add_feature_extractor = _ask_for_new("feature extractor", old_model_infos.feature_extractor_class)
    add_processor = _ask_for_new("processor", old_model_infos.processor_class)

    old_lowercase_name = old_model_infos.lowercase_name
    # A list of the old filenames, along whether we should copy them or not
    filenames_to_add = (
        (f"configuration_{old_lowercase_name}.py", True),
        (f"modeling_{old_lowercase_name}.py", True),
        (f"tokenization_{old_lowercase_name}.py", add_tokenizer),
        (f"tokenization_{old_lowercase_name}_fast.py", add_fast_tokenizer),
        (f"image_processing_{old_lowercase_name}.py", add_image_processor),
        (f"video_processing_{old_lowercase_name}.py", add_video_processor),
        (f"feature_extraction_{old_lowercase_name}.py", add_feature_extractor),
        (f"processing_{old_lowercase_name}.py", add_processor),
    )

    return old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/chat.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import asyncio
15
+ import json
16
+ import os
17
+ import platform
18
+ import re
19
+ import string
20
+ import time
21
+ from collections.abc import AsyncIterator, Awaitable
22
+ from typing import Annotated, Any
23
+ from urllib.parse import urljoin, urlparse
24
+
25
+ import httpx
26
+ import requests
27
+ import typer
28
+ import yaml
29
+ from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput
30
+
31
+ from transformers import GenerationConfig
32
+ from transformers.utils import is_rich_available
33
+
34
+
35
+ try:
36
+ import readline # noqa importing this enables GNU readline capabilities
37
+ except ImportError:
38
+ # some platforms may not support readline: https://docs.python.org/3/library/readline.html
39
+ pass
40
+
41
+ if platform.system() != "Windows":
42
+ import pwd
43
+
44
+ if is_rich_available():
45
+ from rich import filesize
46
+ from rich.console import Console
47
+ from rich.live import Live
48
+ from rich.markdown import Markdown
49
+ from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn, TimeElapsedColumn
50
+ from rich.text import Text
51
+
52
# Host and port of the default server endpoint; also used to build the default `base_url`
DEFAULT_HTTP_ENDPOINT = {"hostname": "localhost", "port": 8000}
# Character whitelists for keys and values -- NOTE(review): presumably consumed by the `key=value` flag parser
# defined elsewhere in this file; confirm against its usage
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
ALLOWED_VALUE_CHARS = set(
    string.ascii_letters + string.digits + string.whitespace + r".!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~"
)

# Built-in prompts selectable with `!example {NAME}` when no examples file is provided
DEFAULT_EXAMPLES = {
    "llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
    "code": {
        "text": (
            "Write a Python function that integrates any Python function f(x) numerically over an arbitrary "
            "interval [x_start, x_end]."
        ),
    },
    "helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
    "numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
    "birds": {"text": "Why aren't birds real?"},
    "socks": {"text": "Why is it important to eat socks after meditating?"},
    "numbers2": {"text": "Which number is larger, 9.9 or 9.11?"},
}

# Printed at the start of a chat session
HELP_STRING_MINIMAL = """

**TRANSFORMERS CHAT INTERFACE**

Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **!help**: shows all available commands (set generation settings, save chat, etc.)
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
"""


# Printed when the user types `!help` in the chat session
HELP_STRING = f"""

**TRANSFORMERS CHAT INTERFACE HELP**

Full command list:
- **!help**: shows this help message
- **!clear**: clears the current conversation and starts a new one
- **!status**: shows the current status of the model and generation settings
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{{MODEL_ID}}/chat_{{DATETIME}}.json` or `{{SAVE_NAME}}` if provided
- **!exit**: closes the interface
"""
104
+
105
+
106
class RichInterface:
    """Console front-end of the chat session: prompts, streamed model output and model-loading progress."""

    def __init__(self, model_id: str, user_id: str, base_url: str):
        self._console = Console()
        self.model_id = model_id
        self.user_id = user_id
        self.base_url = base_url

    async def stream_output(
        self, stream: Awaitable[AsyncIterator[ChatCompletionStreamOutput]]
    ) -> tuple[str, str | Any | None]:
        """
        Render a streamed completion live as Markdown.

        Returns:
            The accumulated (Markdown-escaped) text and the last finish reason reported by the stream, if any.
        """
        self._console.print(f"[bold blue]<{self.model_id}>:")
        with Live(console=self._console, refresh_per_second=4) as live:
            text = ""
            completion_tokens = 0
            start_time = time.time()
            finish_reason: str | None = None
            async for token in await stream:
                usage = getattr(token, "usage", None)
                if usage is not None:
                    completion_tokens = getattr(usage, "completion_tokens", completion_tokens)

                # A chunk may carry no choice at all (e.g. a usage-only chunk): nothing to render
                if not token.choices:
                    continue
                choice = token.choices[0]
                # Keep the last reported reason; a later chunk without one must not reset it
                finish_reason = getattr(choice, "finish_reason", None) or finish_reason

                outputs = choice.delta.content
                if not outputs:
                    continue

                # Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
                # It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
                outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)

                text += outputs
                # Render the accumulated text as Markdown.
                # NOTE: standard Markdown treats a single "\n" in normal text as a space, while chatbot output
                # uses it as a real line break. As a workaround, two trailing spaces (a Markdown hard break) are
                # added to every line except code block markers, where they would break syntax highlighting.
                # Lines inside code blocks also get the trailing spaces, which the console does not care about.
                lines = []
                for line in text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        lines.append("\n")
                    else:
                        lines.append("  \n")

                markdown = Markdown("".join(lines).strip(), code_theme="github-dark")

                # Update the Live console output
                live.update(markdown, refresh=True)

        elapsed = time.time() - start_time
        if elapsed > 0 and completion_tokens > 0:
            tok_per_sec = completion_tokens / elapsed
            self._console.print()
            self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]")
        self._console.print()

        return text, finish_reason

    def input(self) -> str:
        """Gets user input from the console."""
        user_input = self._console.input(f"[bold red]<{self.user_id}>:\n")
        self._console.print()
        return user_input

    def clear(self):
        """Clears the console."""
        self._console.clear()

    def print_user_message(self, text: str):
        """Prints a user message to the console."""
        self._console.print(f"[bold red]<{self.user_id}>:[/ bold red]\n{text}")
        self._console.print()

    def print_color(self, text: str, color: str):
        """Prints text in a given color to the console."""
        self._console.print(f"[bold {color}]{text}")
        self._console.print()

    def confirm(self, message: str, default: bool = False) -> bool:
        """Displays a yes/no prompt to the user, returning True for confirmation."""
        default_hint = "Y/n" if default else "y/N"
        response = self._console.input(f"[bold yellow]{message} ({default_hint}): ")
        self._console.print()

        response = response.strip().lower()
        if not response:
            return default

        return response in {"y", "yes"}

    def print_help(self, minimal: bool = False):
        """Prints the help message to the console."""
        self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
        self._console.print()

    def print_model_load(self, model: str):
        """
        Ask the server to load `model` (POST to `/load_model`) and show a progress bar driven by the
        `data: {...}` JSON events streamed back.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            RuntimeError: If the stream reports an `error` status.
        """

        class StatsColumn(ProgressColumn):
            def render(self, task):
                if not task.total:
                    return Text("")

                if task.fields.get("unit") == "bytes":
                    done = filesize.decimal(int(task.completed))
                    tot = filesize.decimal(int(task.total))
                    speed = f" {filesize.decimal(int(task.speed))}/s" if task.speed else ""

                    if task.time_remaining is not None:
                        eta = f" {int(task.time_remaining // 60)}:{int(task.time_remaining % 60):02d}"
                    else:
                        eta = ""

                    return Text(f"{done}/{tot}{speed}{eta}", style="progress.download")
                return Text(f"{int(task.completed)}/{int(task.total)}")

        stage_labels = {
            "processor": "Loading processor",
            "config": "Loading config",
            "download": "Downloading files",
            "weights": "Loading into memory",
        }

        # Include the model name prefix in descriptions only when the terminal is wide enough.
        # The bar, stats, and elapsed columns need ~70 chars; the model prefix needs len(model)+5.
        show_model_prefix = self._console.width >= len(model) + 5 + 70

        def _label(stage_key):
            stage_text = stage_labels.get(stage_key, stage_key)
            if show_model_prefix:
                return f"{model} → {stage_text}"
            return stage_text

        cached = False
        # The response is streamed: use it as a context manager so the connection is released even when the
        # loop is left early (on `ready`) or an error is raised.
        with requests.post(
            f"{self.base_url.rstrip('/')}/load_model", json={"model": model}, stream=True
        ) as response:
            response.raise_for_status()

            progress = Progress(
                TextColumn("[bold]{task.description}"),
                BarColumn(bar_width=40),
                StatsColumn(),
                TimeElapsedColumn(),
                console=self._console,
            )
            task_id = progress.add_task(_label("processor"), total=None)

            with Live(progress, console=self._console, transient=True):
                for line in response.iter_lines():
                    # Only `data: ` lines carry an event
                    if not line or not line.startswith(b"data: "):
                        continue
                    event = json.loads(line[6:])
                    status = event.get("status")

                    if status == "ready":
                        cached = event.get("cached", False)
                        break

                    if status == "error":
                        raise RuntimeError(event.get("message", "Unknown error"))

                    if status == "loading":
                        stage = event.get("stage")
                        prog = event.get("progress")
                        label = _label(stage)

                        if prog:
                            unit = "bytes" if stage == "download" else "items"
                            progress.update(
                                task_id, description=label, completed=prog["current"], total=prog.get("total"), unit=unit
                            )
                        else:
                            progress.update(task_id, description=label, completed=0, total=None)

        if cached:
            self._console.print(Markdown(f"_*{model} was already loaded.*_"))
        else:
            self._console.print(Markdown(f"_*{model} is warm.*_"))
        self._console.print()

    def print_status(self, config: GenerationConfig):
        """Prints the status of the model and generation settings to the console."""
        self._console.print(f"[bold blue]Model: {self.model_id}\n")
        self._console.print(f"[bold blue]{config}")
        self._console.print()
298
+
299
+
300
+ class Chat:
301
+ """Chat with a model from the command line."""
302
+
303
+ # Defining a class to help with internal state but in practice it's just a method to call
304
+ # TODO: refactor into a proper module with helpers + 1 main method
305
def __init__(
    self,
    model_id: Annotated[str, typer.Argument(help="ID of the model to use (e.g. 'HuggingFaceTB/SmolLM3-3B').")],
    base_url: Annotated[
        str | None, typer.Argument(help="Base url to connect to (e.g. http://localhost:8000/v1).")
    ] = f"http://{DEFAULT_HTTP_ENDPOINT['hostname']}:{DEFAULT_HTTP_ENDPOINT['port']}",
    generate_flags: Annotated[
        list[str] | None,
        typer.Argument(
            help=(
                "Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
                "and lists of integers, more advanced parameterization should be set through --generation-config. "
                "Example: `transformers chat <base_url> <model_id> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
                "If you're a new user, check this basic flag guide: "
                "https://huggingface.co/docs/transformers/llm_tutorial#common-options"
            )
        ),
    ] = None,
    # General settings
    user: Annotated[
        str | None,
        typer.Option(help="Username to display in chat interface. Defaults to the current user's name."),
    ] = None,
    system_prompt: Annotated[str | None, typer.Option(help="System prompt.")] = None,
    save_folder: Annotated[str, typer.Option(help="Folder to save chat history.")] = "./chat_history/",
    examples_path: Annotated[str | None, typer.Option(help="Path to a yaml file with examples.")] = None,
    # Generation settings
    generation_config: Annotated[
        str | None,
        typer.Option(
            help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config."
        ),
    ] = None,
) -> None:
    """Chat with a model from the command line."""
    self.base_url = base_url

    # The health check is only performed when the target is the default local endpoint
    endpoint = urlparse(self.base_url)
    targets_default_endpoint = (
        endpoint.hostname == DEFAULT_HTTP_ENDPOINT["hostname"] and endpoint.port == DEFAULT_HTTP_ENDPOINT["port"]
    )
    if targets_default_endpoint:
        self.check_health(self.base_url)

    self.model_id = model_id
    self.system_prompt = system_prompt
    self.save_folder = save_folder

    # Generation settings: loaded config first, then the chat defaults, then the user flags on top
    generation_settings = load_generation_config(generation_config)
    generation_settings.update(do_sample=True, max_new_tokens=256)  # some default values
    generation_settings.update(**parse_generate_flags(generate_flags))
    self.config = generation_settings

    self.settings = {"base_url": base_url, "model_id": model_id, "config": self.config.to_dict()}

    # Displayed user name: the explicit one when given, otherwise whatever `get_username()` returns
    self.user = get_username() if user is None else user

    # Examples: the built-in ones unless a yaml file is provided
    if not examples_path:
        self.examples = DEFAULT_EXAMPLES
    else:
        with open(examples_path) as f:
            self.examples = yaml.safe_load(f)

    # `rich` is required by the interface started below
    if not is_rich_available():
        raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")

    # Constructing the object runs the whole chat session
    asyncio.run(self._inner_run())
374
+
375
@staticmethod
def check_health(url):
    """
    Check that a server answers on the `/health` route of `url`.

    Returns:
        `True` when the route answers with status code 200.

    Raises:
        ValueError: If the connection fails, or if the route answers with another status code.
    """
    health_url = urljoin(url + "/", "health")
    try:
        output = httpx.get(health_url)
    except httpx.ConnectError as err:
        # Adjacent literals are concatenated as-is: the first one must end with a space
        raise ValueError(
            f"No server currently running on {url}. To run a local server, please run `transformers serve` in a "
            "separate shell. Find more information here: https://huggingface.co/docs/transformers/serving"
        ) from err

    if output.status_code != 200:
        raise ValueError(
            f"The server running on {url} returned status code {output.status_code} on health check (/health)."
        )

    return True
391
+
392
def handle_non_exit_user_commands(
    self,
    user_input: str,
    interface: RichInterface,
    examples: dict[str, dict[str, str]],
    config: GenerationConfig,
    chat: list[dict],
) -> tuple[list[dict], bool, GenerationConfig]:
    """
    Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
    generation config (e.g. set a new flag).

    Returns:
        The (possibly replaced) chat history, whether `user_input` was a valid command, and the generation config.
    """
    valid_command = True

    if user_input == "!clear":
        chat = new_chat_history(self.system_prompt)
        interface.clear()

    elif user_input == "!help":
        interface.print_help()

    # `!save` alone, or `!save NAME`: at most two tokens (a `< 2` guard would reject the documented
    # `!save NAME` form)
    elif user_input.startswith("!save") and len(user_input.split()) <= 2:
        split_input = user_input.split()
        filename = (
            split_input[1]
            if len(split_input) == 2
            else os.path.join(self.save_folder, self.model_id, f"chat_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json")
        )
        save_chat(filename=filename, chat=chat, settings=self.settings)
        interface.print_color(text=f"Chat saved to {filename}!", color="green")

    elif user_input.startswith("!set"):
        # splits the new args into a list of strings, each string being a `flag=value` pair (same format as
        # `generate_flags`)
        new_generate_flags = user_input[4:].strip().split()
        # sanity check: each member in the list must have an =
        for flag in new_generate_flags:
            if "=" not in flag:
                interface.print_color(
                    text=(
                        f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
                        "`arg_1=value_1 arg_2=value_2 ...`."
                    ),
                    color="red",
                )
                break
        else:
            # Only reached when no flag was rejected above: update config from user flags
            config.update(**parse_generate_flags(new_generate_flags))

    elif user_input.startswith("!example") and len(user_input.split()) == 2:
        example_name = user_input.split()[1]
        if example_name in examples:
            interface.clear()
            # The history restarts from an empty list holding only the example as the user message
            chat = []
            interface.print_user_message(examples[example_name]["text"])
            chat.append({"role": "user", "content": examples[example_name]["text"]})
        else:
            example_error = (
                f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
            )
            interface.print_color(text=example_error, color="red")

    elif user_input == "!status":
        interface.print_status(config=config)

    else:
        valid_command = False
        interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
        interface.print_help()

    return chat, valid_command, config
465
+
466
async def _inner_run(self):
    """
    Runs the interactive chat loop until the user types `!exit` or interrupts with Ctrl+C.

    Each iteration reads one line of input (or reuses a pending auto-continue prompt), handles `!` commands
    locally, and otherwise sends the chat history to the server and streams the answer back.
    """
    interface = RichInterface(model_id=self.model_id, user_id=self.user, base_url=self.base_url)
    interface.clear()
    chat = new_chat_history(self.system_prompt)

    # Starts the session with a minimal help message at the top, so that a user doesn't get stuck
    interface.print_help(minimal=True)
    interface.print_model_load(self.model_id)

    config = self.config

    async with AsyncInferenceClient(base_url=self.base_url) as client:
        # Set after a generation stops on the token limit and the user agrees to continue
        pending_user_input: str | None = None
        while True:
            try:
                if pending_user_input is not None:
                    user_input = pending_user_input
                    pending_user_input = None
                    interface.print_user_message(user_input)
                else:
                    user_input = interface.input()

                # User commands
                if user_input == "!exit":
                    break

                elif user_input == "!clear":
                    chat = new_chat_history(self.system_prompt)
                    interface.clear()
                    continue

                elif user_input == "!help":
                    interface.print_help()
                    continue

                # `!save` alone (1 token) or `!save <filename>` (2 tokens). The previous `< 2` bound made
                # the explicit-filename form unreachable.
                elif user_input.startswith("!save") and len(user_input.split()) <= 2:
                    split_input = user_input.split()
                    if len(split_input) == 2:
                        # Absolute path so that a bare file name still has a directory component
                        filename = os.path.abspath(split_input[1])
                    else:
                        filename = os.path.join(
                            self.save_folder, self.model_id, f"chat_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json"
                        )
                    save_chat(filename=filename, chat=chat, settings=self.settings)
                    interface.print_color(text=f"Chat saved to {filename}!", color="green")
                    continue

                elif user_input.startswith("!set"):
                    # splits the new args into a list of strings, each string being a `flag=value` pair (same
                    # format as `generate_flags`)
                    new_generate_flags = user_input[4:].strip()
                    new_generate_flags = new_generate_flags.split()
                    # sanity check: each member in the list must have an =
                    for flag in new_generate_flags:
                        if "=" not in flag:
                            interface.print_color(
                                text=(
                                    f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
                                    "`arg_1=value_1 arg_2=value_2 ...`."
                                ),
                                color="red",
                            )
                            break
                    else:
                        # Update config from user flags
                        config.update(**parse_generate_flags(new_generate_flags))
                    continue

                elif user_input.startswith("!example") and len(user_input.split()) == 2:
                    example_name = user_input.split()[1]
                    if example_name in self.examples:
                        interface.clear()
                        chat = []
                        interface.print_user_message(self.examples[example_name]["text"])
                        chat.append({"role": "user", "content": self.examples[example_name]["text"]})
                        # No `continue`: the example text is sent to the model right away, below.
                    else:
                        example_error = f"Example {example_name} not found in list of available examples: {list(self.examples.keys())}."
                        interface.print_color(text=example_error, color="red")
                        # Nothing new to send: without this the unchanged history would be re-submitted.
                        continue

                elif user_input == "!status":
                    interface.print_status(config=config)
                    continue

                elif user_input.startswith("!"):
                    interface.print_color(
                        text=f"'{user_input}' is not a valid command. Showing help message.", color="red"
                    )
                    interface.print_help()
                    continue

                else:
                    chat.append({"role": "user", "content": user_input})

                extra_body = {
                    "generation_config": config.to_json_string(),
                    "model": self.model_id,
                }

                stream = client.chat_completion(
                    chat,
                    stream=True,
                    model=self.model_id,
                    extra_body=extra_body,
                )

                model_output, finish_reason = await interface.stream_output(stream)

                chat.append({"role": "assistant", "content": model_output})

                if finish_reason == "length":
                    interface.print_color("Generation stopped after reaching the token limit.", "yellow")
                    if interface.confirm("Continue generating?"):
                        # Sent as the next user turn (a stray closing curly quote was removed from this prompt)
                        pending_user_input = "Please continue. Do not repeat text."
                        continue
            except KeyboardInterrupt:
                break
583
+
584
+
585
def load_generation_config(generation_config: str | None) -> GenerationConfig:
    """
    Builds the `GenerationConfig` to use from an optional user-provided reference.

    `None` yields a default config. A value containing `.json` is treated as a local file and split into
    directory and file name; any other value is passed to `GenerationConfig.from_pretrained` as is.
    """
    if generation_config is None:
        return GenerationConfig()

    if ".json" not in generation_config:
        return GenerationConfig.from_pretrained(generation_config)

    # is a local file
    config_dir, config_file = os.path.split(generation_config)
    return GenerationConfig.from_pretrained(config_dir, config_file)
595
+
596
+
597
+ def parse_generate_flags(generate_flags: list[str] | None) -> dict:
598
+ """Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
599
+ if generate_flags is None or len(generate_flags) == 0:
600
+ return {}
601
+
602
+ # Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
603
+ # into a json string if we:
604
+ # 1. Add quotes around each flag name
605
+ generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
606
+
607
+ # 2. Handle types:
608
+ # 2. a. booleans should be lowercase, None should be null
609
+ generate_flags_as_dict = {
610
+ k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
611
+ }
612
+ generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
613
+
614
+ # 2. b. strings should be quoted
615
+ def is_number(s: str) -> bool:
616
+ # handle negative numbers
617
+ s = s.removeprefix("-")
618
+ return s.replace(".", "", 1).isdigit()
619
+
620
+ generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
621
+ # 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
622
+ # We also mention in the help message that we only accept lists of ints for now.
623
+
624
+ # 3. Join the result into a comma separated string
625
+ generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
626
+
627
+ # 4. Add the opening/closing brackets
628
+ generate_flags_string = "{" + generate_flags_string + "}"
629
+
630
+ # 5. Remove quotes around boolean/null and around lists
631
+ generate_flags_string = generate_flags_string.replace('"null"', "null")
632
+ generate_flags_string = generate_flags_string.replace('"true"', "true")
633
+ generate_flags_string = generate_flags_string.replace('"false"', "false")
634
+ generate_flags_string = generate_flags_string.replace('"[', "[")
635
+ generate_flags_string = generate_flags_string.replace(']"', "]")
636
+
637
+ # 6. Replace the `=` with `:`
638
+ generate_flags_string = generate_flags_string.replace("=", ":")
639
+
640
+ try:
641
+ processed_generate_flags = json.loads(generate_flags_string)
642
+ except json.JSONDecodeError:
643
+ raise ValueError(
644
+ "Failed to convert `generate_flags` into a valid JSON object."
645
+ "\n`generate_flags` = {generate_flags}"
646
+ "\nConverted JSON string = {generate_flags_string}"
647
+ )
648
+ return processed_generate_flags
649
+
650
+
651
+ def new_chat_history(system_prompt: str | None = None) -> list[dict]:
652
+ """Returns a new chat conversation."""
653
+ return [{"role": "system", "content": system_prompt}] if system_prompt else []
654
+
655
+
656
def save_chat(filename: str, chat: list[dict], settings: dict) -> str:
    """
    Saves the chat history to a file.

    Args:
        filename: Destination path. Missing parent directories are created.
        chat: The chat history to save.
        settings: Settings stored alongside the chat history.

    Returns:
        The absolute path of the written file.
    """
    parent_dir = os.path.dirname(filename)
    # A bare file name has no directory component, and `os.makedirs("")` raises FileNotFoundError
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, "w", encoding="utf-8") as f:
        json.dump({"settings": settings, "chat_history": chat}, f, indent=4)
    return os.path.abspath(filename)
662
+
663
+
664
def get_username() -> str:
    """Looks up the current user's name: `os.getlogin()` on Windows, the `pwd` database entry elsewhere."""
    if platform.system() != "Windows":
        return pwd.getpwuid(os.getuid()).pw_name
    return os.getlogin()
670
+
671
+
672
# Running this module directly constructs `Chat` with a hard-coded model id.
# NOTE(review): presumably constructing `Chat` starts the interactive session — confirm against `Chat.__init__`.
if __name__ == "__main__":
    Chat(model_id="meta-llama/Llama-3.2-3b-Instruct")
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/download.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Annotated
15
+
16
+ import typer
17
+
18
+
19
def download(
    model_id: Annotated[str, typer.Argument(help="The model ID to download")],
    cache_dir: Annotated[str | None, typer.Option(help="Directory where to save files.")] = None,
    force_download: Annotated[
        bool, typer.Option(help="If set, the files will be downloaded even if they are already cached locally.")
    ] = False,
    trust_remote_code: Annotated[
        bool,
        typer.Option(
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine"
        ),
    ] = False,
):
    """Download a model and its tokenizer from the Hub."""
    # Imported here rather than at module level, as in the original.
    from ..models.auto import AutoModel, AutoTokenizer

    # The model is fetched first, then the tokenizer, both with the same options.
    for auto_class in (AutoModel, AutoTokenizer):
        auto_class.from_pretrained(
            model_id, cache_dir=cache_dir, force_download=force_download, trust_remote_code=trust_remote_code
        )
+ )
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/serve.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ CLI entry point for `transformers serve`.
16
+ """
17
+
18
+ import asyncio
19
+ import enum
20
+ import json
21
+ import threading
22
+ from typing import Annotated
23
+
24
+ import typer
25
+
26
+ from transformers.utils import logging
27
+ from transformers.utils.import_utils import is_serve_available
28
+
29
+ from .serving.utils import set_torch_seed
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
class ReasoningMode(str, enum.Enum):
    """Accepted values for the `reasoning` option of `Serve`."""

    # Forces `enable_thinking=True` in the chat template kwargs
    ON = "on"
    # Forces `enable_thinking=False` in the chat template kwargs
    OFF = "off"
    # Leaves `enable_thinking` untouched
    AUTO = "auto"
39
+
40
+
41
# NOTE: the class docstring (used as CLI help) is assigned to `Serve.__doc__` after the class body.
class Serve:
    # Builds the model manager, request handlers and uvicorn server, then starts serving: in a background
    # thread when `non_blocking` is set, otherwise blocking in the constructor until the server stops.
    def __init__(
        self,
        force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None,
        # Model options
        continuous_batching: Annotated[
            bool,
            typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."),
        ] = False,
        attn_implementation: Annotated[
            str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
        ] = None,
        compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
        quantization: Annotated[
            str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
        ] = None,
        reasoning: Annotated[
            ReasoningMode,
            typer.Option(
                help=(
                    "Reasoning mode. 'auto' uses the chat template default. Only applies to models that "
                    "support reasoning via their chat template (e.g. Qwen3, Gemma 4) — for other models "
                    "this flag has no effect."
                )
            ),
        ] = ReasoningMode.AUTO,
        chat_template_kwargs: Annotated[
            str | None,
            typer.Option(
                help=(
                    "Default JSON kwargs forwarded to apply_chat_template "
                    "(e.g. '{\"enable_thinking\": true}'); per-request chat_template_kwargs override these."
                )
            ),
        ] = None,
        device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
        dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
        trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
        model_timeout: Annotated[
            int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
        ] = 300,
        # Continuous batching tuning
        cb_block_size: Annotated[
            int | None, typer.Option(help="KV cache block size in tokens for continuous batching.")
        ] = None,
        cb_num_blocks: Annotated[
            int | None, typer.Option(help="Number of KV cache blocks for continuous batching.")
        ] = None,
        cb_max_batch_tokens: Annotated[
            int | None, typer.Option(help="Maximum tokens per batch for continuous batching.")
        ] = None,
        cb_max_memory_percent: Annotated[
            float | None, typer.Option(help="Max GPU memory fraction for KV cache (0.0-1.0).")
        ] = None,
        cb_use_cuda_graph: Annotated[
            bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.")
        ] = None,
        # Server options
        host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost",
        port: Annotated[int, typer.Option(help="Server listen port.")] = 8000,
        enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False,
        log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning",
        default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None,
        non_blocking: Annotated[
            bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.")
        ] = False,
    ) -> None:
        if not is_serve_available():
            raise ImportError("Missing dependencies for serving. Install with `pip install transformers[serving]`")

        import uvicorn

        from .serving.chat_completion import ChatCompletionHandler
        from .serving.completion import CompletionHandler
        from .serving.model_manager import ModelManager
        from .serving.response import ResponseHandler
        from .serving.server import build_server
        from .serving.transcription import TranscriptionHandler
        from .serving.utils import GenerationState

        # Seed
        if default_seed is not None:
            set_torch_seed(default_seed)

        # Logging
        transformers_logger = logging.get_logger("transformers")
        transformers_logger.setLevel(logging.log_levels[log_level.lower()])

        self._model_manager = ModelManager(
            device=device,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            quantization=quantization,
            model_timeout=model_timeout,
            force_model=force_model,
        )
        from transformers import ContinuousBatchingConfig

        # Only forward the continuous batching options that were explicitly set
        cb_kwargs = {
            k: v
            for k, v in {
                "block_size": cb_block_size,
                "num_blocks": cb_num_blocks,
                "max_batch_tokens": cb_max_batch_tokens,
                "max_memory_percent": cb_max_memory_percent,
                "use_cuda_graph": cb_use_cuda_graph,
            }.items()
            if v is not None
        }
        cb_config = ContinuousBatchingConfig(**cb_kwargs) if cb_kwargs else None
        self._generation_state = GenerationState(
            continuous_batching=continuous_batching,
            compile=compile,
            cb_config=cb_config,
        )

        if chat_template_kwargs:
            try:
                chat_template_kwargs = json.loads(chat_template_kwargs)
            except json.JSONDecodeError as err:
                # Report malformed JSON as a CLI usage error instead of a raw traceback
                raise typer.BadParameter(f"--chat-template-kwargs must be valid JSON: {err}") from err
            if not isinstance(chat_template_kwargs, dict):
                raise typer.BadParameter("--chat-template-kwargs must be a JSON object")
        else:
            chat_template_kwargs = {}

        # An explicit reasoning mode takes precedence over `enable_thinking` from --chat-template-kwargs
        if reasoning == ReasoningMode.ON:
            chat_template_kwargs["enable_thinking"] = True
        elif reasoning == ReasoningMode.OFF:
            chat_template_kwargs["enable_thinking"] = False

        self._chat_handler = ChatCompletionHandler(
            model_manager=self._model_manager,
            generation_state=self._generation_state,
            chat_template_kwargs=chat_template_kwargs,
        )

        self._completion_handler = CompletionHandler(
            model_manager=self._model_manager,
            generation_state=self._generation_state,
        )

        self._response_handler = ResponseHandler(
            model_manager=self._model_manager,
            generation_state=self._generation_state,
            chat_template_kwargs=chat_template_kwargs,
        )

        self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state)

        app = build_server(
            self._model_manager,
            self._chat_handler,
            completion_handler=self._completion_handler,
            response_handler=self._response_handler,
            transcription_handler=self._transcription_handler,
            generation_state=self._generation_state,
            enable_cors=enable_cors,
        )

        config = uvicorn.Config(app, host=host, port=port, log_level="info")
        self.server = uvicorn.Server(config)

        # Always defined, so that `kill_server` is safe even when the server never ran in a background thread
        self._thread: threading.Thread | None = None

        if non_blocking:
            self.start_server()
        else:
            self.server.run()

    def start_server(self):
        """Start the uvicorn server on its own event loop, in a background thread."""

        def _run():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.server.serve())

        self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False)
        self._thread.start()

    def reset_loaded_models(self):
        """Clear all loaded models from memory."""
        self._model_manager.shutdown()

    def kill_server(self):
        """Shut down the generation state and model manager, then stop the background server thread, if any."""
        self._generation_state.shutdown()
        self._model_manager.shutdown()
        if not self._thread or not self._thread.is_alive():
            return
        self.server.should_exit = True
        # Bounded wait for the server thread to finish
        self._thread.join(timeout=2)
+
228
+
229
# Assigned after the class body so it becomes the class docstring.
# NOTE(review): presumably shown as the `serve` command help; the lone `\b` line looks like Click's
# marker to keep the following paragraph from being re-wrapped — confirm.
Serve.__doc__ = """
Run a FastAPI server to serve models on-demand with an OpenAI compatible API.
Models will be loaded and unloaded automatically based on usage and a timeout.

\b
Endpoints:
POST /v1/chat/completions — Chat completions (streaming + non-streaming).
POST /v1/completions — Legacy text completions from a prompt.
GET /v1/models — Lists available models.
GET /health — Health check.

Requires FastAPI and Uvicorn: pip install transformers[serving]
"""
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/system.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains commands to print information about the environment and version.
15
+
16
+ Usage:
17
+ transformers env
18
+ transformers version
19
+ """
20
+
21
+ import contextlib
22
+ import io
23
+ import os
24
+ import platform
25
+ from typing import Annotated
26
+
27
+ import huggingface_hub
28
+ import typer
29
+
30
+ from .. import __version__
31
+ from ..integrations.deepspeed import is_deepspeed_available
32
+ from ..utils import (
33
+ is_accelerate_available,
34
+ is_torch_available,
35
+ is_torch_hpu_available,
36
+ is_torch_npu_available,
37
+ is_torch_xpu_available,
38
+ )
39
+
40
+
41
def env(
    accelerate_config_file: Annotated[
        str | None,
        typer.Argument(help="The accelerate config file to use for the default values in the launching script."),
    ] = None,
) -> dict:
    """Print information about the environment."""
    # Besides printing, the collected information is returned as a dict (in display order); the return
    # annotation used to say `None`.
    import safetensors

    # TODO: remove hasattr guard once safetensors >= 0.8.0 is released (adds __version__)
    safetensors_version = safetensors.__version__ if hasattr(safetensors, "__version__") else "unknown"

    accelerate_version = "not installed"
    accelerate_config = accelerate_config_str = "not found"

    if is_accelerate_available():
        import accelerate
        from accelerate.commands.config import default_config_file, load_config_from_file

        accelerate_version = accelerate.__version__
        # Get the default from the config file.
        if accelerate_config_file is not None or os.path.isfile(default_config_file):
            accelerate_config = load_config_from_file(accelerate_config_file).to_dict()

        # One bullet per config entry when a config was loaded, otherwise the "not found" placeholder
        accelerate_config_str = (
            "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
            if isinstance(accelerate_config, dict)
            else f"\t{accelerate_config}"
        )

    pt_version = "not installed"
    pt_cuda_available = "NA"
    pt_accelerator = "NA"
    if is_torch_available():
        import torch

        pt_version = torch.__version__
        pt_cuda_available = torch.cuda.is_available()
        pt_xpu_available = is_torch_xpu_available()
        pt_npu_available = is_torch_npu_available()
        pt_hpu_available = is_torch_hpu_available()

        # First available accelerator wins, in this order
        if pt_cuda_available:
            pt_accelerator = "CUDA"
        elif pt_xpu_available:
            pt_accelerator = "XPU"
        elif pt_npu_available:
            pt_accelerator = "NPU"
        elif pt_hpu_available:
            pt_accelerator = "HPU"

    deepspeed_version = "not installed"
    if is_deepspeed_available():
        # Redirect command line output to silence deepspeed import output.
        with contextlib.redirect_stdout(io.StringIO()):
            import deepspeed
        deepspeed_version = deepspeed.__version__

    info = {
        "`transformers` version": __version__,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": huggingface_hub.__version__,
        "Safetensors version": f"{safetensors_version}",
        "Accelerate version": f"{accelerate_version}",
        "Accelerate config": f"{accelerate_config_str}",
        "DeepSpeed version": f"{deepspeed_version}",
        "PyTorch version (accelerator?)": f"{pt_version} ({pt_accelerator})",
        "Using distributed or parallel set-up in script?": "<fill in>",
    }
    # Device details; the `pt_*_available` flags only exist when torch is installed
    if is_torch_available():
        if pt_cuda_available:
            info["Using GPU in script?"] = "<fill in>"
            info["GPU type"] = torch.cuda.get_device_name()
        elif pt_xpu_available:
            info["Using XPU in script?"] = "<fill in>"
            info["XPU type"] = torch.xpu.get_device_name()
        elif pt_hpu_available and hasattr(torch, "hpu"):
            info["Using HPU in script?"] = "<fill in>"
            info["HPU type"] = torch.hpu.get_device_name()
        elif pt_npu_available and hasattr(torch, "npu"):
            info["Using NPU in script?"] = "<fill in>"
            info["NPU type"] = torch.npu.get_device_name()
            if hasattr(torch.version, "cann"):
                info["CANN version"] = torch.version.cann

    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
    print(_format_dict(info))

    return info
131
+
132
+
133
def version() -> None:
    """Print CLI version."""
    # `__version__` is the package-level version imported at the top of this module
    print(__version__)
136
+
137
+
138
def _format_dict(d: dict) -> str:
    """Renders `d` as one `- key: value` bullet per line, followed by a trailing newline."""
    bullet_lines = [f"- {prop}: {val}" for prop, val in d.items()]
    return "\n".join(bullet_lines) + "\n"
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cli/transformers.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Transformers CLI."""
15
+
16
+ from huggingface_hub import check_cli_update, typer_factory
17
+
18
+ from transformers.cli.add_new_model_like import add_new_model_like
19
+ from transformers.cli.chat import Chat
20
+ from transformers.cli.download import download
21
+ from transformers.cli.serve import Serve
22
+ from transformers.cli.system import env, version
23
+
24
+
25
# Root CLI application; each callable below is registered as a sub-command.
app = typer_factory(help="Transformers CLI")

# Commands registered without an explicit `name` let the framework derive it from the callable.
app.command()(add_new_model_like)
# `Chat` and `Serve` are classes, so their command names are given explicitly.
app.command(name="chat")(Chat)
app.command()(download)
app.command()(env)
app.command(name="serve")(Serve)
app.command()(version)
33
+
34
+
35
def main():
    """Entry point of the CLI: calls `check_cli_update("transformers")`, then runs the `app`."""
    check_cli_update("transformers")
    app()


if __name__ == "__main__":
    main()
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import _LazyModule
18
+
19
+
20
# Maps each submodule of this package to the public names it provides; passed to `_LazyModule` below.
_import_structure = {
    "configuration_utils": ["DistributedConfig"],
}


if TYPE_CHECKING:
    # Direct import, only executed when `TYPE_CHECKING` is true (i.e. for static type checkers)
    from .configuration_utils import (
        DistributedConfig,
    )

else:
    import sys

    # At runtime this module object is replaced in `sys.modules` by a `_LazyModule` built from
    # `_import_structure`. NOTE(review): presumably this defers importing the submodules until an
    # attribute is first accessed — confirm against `_LazyModule`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/distributed/configuration_utils.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import json
17
+ import os
18
+ from dataclasses import dataclass
19
+ from typing import Any
20
+
21
+
22
+ @dataclass
23
+ class DistributedConfig:
24
+ """
25
+ Base class for distributed configs
26
+ """
27
+
28
+ enable_expert_parallel: bool = False
29
+ # TODO: add tp_plan, pp_plan, device_mesh etc..
30
+
31
+ @classmethod
32
+ def from_dict(cls, config_dict, **kwargs):
33
+ """
34
+ Constructs a DistributedConfig instance from a dictionary of parameters.
35
+ Args:
36
+ config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
37
+ **kwargs: Additional keyword arguments to override dictionary values.
38
+ Returns:
39
+ DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
40
+ """
41
+ config = cls(**config_dict)
42
+ to_remove = []
43
+ for key, value in kwargs.items():
44
+ if hasattr(config, key):
45
+ setattr(config, key, value)
46
+ to_remove.append(key)
47
+ for key in to_remove:
48
+ kwargs.pop(key, None)
49
+ return config
50
+
51
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
52
+ def to_json_file(self, json_file_path: str | os.PathLike):
53
+ """
54
+ Save this instance to a JSON file.
55
+ Args:
56
+ json_file_path (`str` or `os.PathLike`):
57
+ Path to the JSON file in which this configuration instance's parameters will be saved.
58
+ use_diff (`bool`, *optional*, defaults to `True`):
59
+ If set to `True`, only the difference between the config instance and the default
60
+ `QuantizationConfig()` is serialized to JSON file.
61
+ """
62
+ with open(json_file_path, "w", encoding="utf-8") as writer:
63
+ config_dict = self.to_dict()
64
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
65
+
66
+ writer.write(json_string)
67
+
68
+ def to_dict(self) -> dict[str, Any]:
69
+ """
70
+ Serializes this instance to a Python dictionary. Returns:
71
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
72
+ """
73
+ return copy.deepcopy(self.__dict__)
74
+
75
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
76
+ def __iter__(self):
77
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
78
+ yield from copy.deepcopy(self.__dict__).items()
79
+
80
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
81
+ def __repr__(self):
82
+ return f"{self.__class__.__name__} {self.to_json_string()}"
83
+
84
+ def to_json_string(self):
85
+ """
86
+ Serializes this instance to a JSON formatted string.
87
+ Returns:
88
+ str: JSON formatted string representing the configuration instance.
89
+ """
90
+ return json.dumps(self.__dict__, indent=2) + "\n"
91
+
92
+ def update(self, **kwargs):
93
+ """
94
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
95
+ returning all the unused kwargs.
96
+ Args:
97
+ kwargs (`Dict[str, Any]`):
98
+ Dictionary of attributes to tentatively update this class.
99
+ Returns:
100
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
101
+ """
102
+ to_remove = []
103
+ for key, value in kwargs.items():
104
+ if hasattr(self, key):
105
+ setattr(self, key, value)
106
+ to_remove.append(key)
107
+
108
+ # Remove all the attributes that were updated, without modifying the input dict
109
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
110
+ return unused_kwargs
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/hyperparameter_search.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .integrations import (
16
+ is_optuna_available,
17
+ is_ray_tune_available,
18
+ is_wandb_available,
19
+ run_hp_search_optuna,
20
+ run_hp_search_ray,
21
+ run_hp_search_wandb,
22
+ )
23
+ from .trainer_utils import (
24
+ HPSearchBackend,
25
+ default_hp_space_optuna,
26
+ default_hp_space_ray,
27
+ default_hp_space_wandb,
28
+ )
29
+ from .utils import logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class HyperParamSearchBackendBase:
36
+ name: str
37
+ pip_package: str | None = None
38
+
39
+ @staticmethod
40
+ def is_available():
41
+ raise NotImplementedError
42
+
43
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
44
+ raise NotImplementedError
45
+
46
+ def default_hp_space(self, trial):
47
+ raise NotImplementedError
48
+
49
+ def ensure_available(self):
50
+ if not self.is_available():
51
+ raise RuntimeError(
52
+ f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
53
+ )
54
+
55
+ @classmethod
56
+ def pip_install(cls):
57
+ return f"`pip install {cls.pip_package or cls.name}`"
58
+
59
+
60
class OptunaBackend(HyperParamSearchBackendBase):
    """Hyperparameter-search backend that delegates to the Optuna integration helpers."""

    name = "optuna"

    @staticmethod
    def is_available():
        """Return the result of `is_optuna_available()`."""
        return is_optuna_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward the search request unchanged to `run_hp_search_optuna`."""
        return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return `default_hp_space_optuna(trial)`."""
        return default_hp_space_optuna(trial)
72
+
73
+
74
class RayTuneBackend(HyperParamSearchBackendBase):
    """Hyperparameter-search backend that delegates to the Ray Tune integration helpers."""

    name = "ray"
    # The pip requirement differs from `name`, so it is spelled out (quotes included) for the install hint.
    pip_package = "'ray[tune]'"

    @staticmethod
    def is_available():
        """Return the result of `is_ray_tune_available()`."""
        return is_ray_tune_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward the search request unchanged to `run_hp_search_ray`."""
        return run_hp_search_ray(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return `default_hp_space_ray(trial)`."""
        return default_hp_space_ray(trial)
87
+
88
+
89
class WandbBackend(HyperParamSearchBackendBase):
    """Hyperparameter-search backend that delegates to the Weights & Biases integration helpers."""

    name = "wandb"

    @staticmethod
    def is_available():
        """Return the result of `is_wandb_available()`."""
        return is_wandb_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward the search request unchanged to `run_hp_search_wandb`."""
        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return `default_hp_space_wandb(trial)`."""
        return default_hp_space_wandb(trial)
101
+
102
+
103
# Registry keyed by the `HPSearchBackend` member built from each backend's `name`; the values are the backend
# classes themselves (not instances). Insertion order is Optuna, Ray Tune, W&B.
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
    HPSearchBackend(backend_cls.name): backend_cls
    for backend_cls in (OptunaBackend, RayTuneBackend, WandbBackend)
}
106
+
107
+
108
def default_hp_search_backend() -> str:
    """
    Pick the default hyperparameter-search backend.

    Returns:
        `str`: The `name` of the first backend in `ALL_HYPERPARAMETER_SEARCH_BACKENDS` whose `is_available()`
        returns a truthy value. An info message is logged when more than one backend is available.

    Raises:
        RuntimeError: If no backend is available; the message lists the install command of every known backend.
    """
    installed = [
        candidate for candidate in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if candidate.is_available()
    ]
    if not installed:
        install_hints = "\n".join(
            f" - To install {backend.name} run {backend.pip_install()}"
            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
        )
        raise RuntimeError("No hyperparameter search backend available.\n" + install_hints)

    name = installed[0].name
    if len(installed) > 1:
        logger.info(
            f"{len(installed)} hyperparameter search backends available. Using {name} as the default."
        )
    return name
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_transforms.py ADDED
@@ -0,0 +1,1073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ from collections.abc import Collection, Iterable
17
+ from math import ceil
18
+ from typing import Optional, Union
19
+
20
+ import numpy as np
21
+
22
+ from .image_utils import (
23
+ ChannelDimension,
24
+ ImageInput,
25
+ get_channel_dimension_axis,
26
+ get_image_size,
27
+ infer_channel_dimension_format,
28
+ )
29
+ from .utils import ExplicitEnum, TensorType, is_torch_tensor
30
+ from .utils.import_utils import (
31
+ is_torch_available,
32
+ is_vision_available,
33
+ requires_backends,
34
+ )
35
+
36
+
37
+ if is_vision_available():
38
+ import PIL
39
+
40
+ from .image_utils import PILImageResampling
41
+
42
+ if is_torch_available():
43
+ import torch
44
+
45
+
46
def to_channel_dimension_format(
    image: np.ndarray,
    channel_dim: ChannelDimension | str,
    input_channel_dim: ChannelDimension | str | None = None,
) -> np.ndarray:
    """
    Move the channel axis of `image` to the position described by `channel_dim`. Any number of leading
    dimensions is allowed; only the last three dimensions are permuted.

    Args:
        image (`numpy.ndarray`):
            The image whose channel dimension should be moved.
        channel_dim (`ChannelDimension`):
            The channel dimension format to produce.
        input_channel_dim (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it is inferred from the image.

    Returns:
        `np.ndarray`: The image with the channel dimension set to `channel_dim`. The input object itself is
        returned when it is already in the requested format.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)

    target_channel_dim = ChannelDimension(channel_dim)
    if input_channel_dim == target_channel_dim:
        return image

    # Leading (batch-like) axes keep their position; only the trailing three axes are reordered.
    leading_axes = list(range(image.ndim - 3))
    third_last, second_last, last = image.ndim - 3, image.ndim - 2, image.ndim - 1

    if target_channel_dim == ChannelDimension.FIRST:
        return image.transpose(leading_axes + [last, third_last, second_last])
    if target_channel_dim == ChannelDimension.LAST:
        return image.transpose(leading_axes + [second_last, last, third_last])
    raise ValueError(f"Unsupported channel dimension format: {channel_dim}")
87
+
88
+
89
def rescale(
    image: np.ndarray,
    scale: float,
    data_format: ChannelDimension | None = None,
    dtype: np.dtype = np.float32,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Multiply every value of `image` by `scale`.

    Args:
        image (`np.ndarray`):
            The image to rescale.
        scale (`float`):
            The factor applied to the image values.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output. If not provided, the layout of the input is kept.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            The dtype of the returned image. Kept for backwards compatibility with feature extractors.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it is inferred from the image.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    # NumPy type promotion has changed across versions, so the product is always computed in float64 ...
    scaled = image.astype(np.float64) * scale
    if data_format is not None:
        scaled = to_channel_dimension_format(scaled, data_format, input_data_format)

    # ... and only cast down to the requested dtype at the very end.
    return scaled.astype(dtype)
125
+
126
+
127
+ def _rescale_for_pil_conversion(image):
128
+ """
129
+ Detects whether or not the image needs to be rescaled before being converted to a PIL image.
130
+
131
+ The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be
132
+ rescaled.
133
+ """
134
+ if image.dtype == np.uint8:
135
+ do_rescale = False
136
+ elif np.allclose(image, image.astype(int)):
137
+ if np.all(image >= 0) and np.all(image <= 255):
138
+ do_rescale = False
139
+ else:
140
+ raise ValueError(
141
+ "The image to be converted to a PIL image contains values outside the range [0, 255], "
142
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
143
+ )
144
+ elif np.all(image >= 0) and np.all(image <= 1):
145
+ do_rescale = True
146
+ else:
147
+ raise ValueError(
148
+ "The image to be converted to a PIL image contains values outside the range [0, 1], "
149
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
150
+ )
151
+ return do_rescale
152
+
153
+
154
def to_pil_image(
    image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor"],
    do_rescale: bool | None = None,
    image_mode: str | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> "PIL.Image.Image":
    """
    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
    needed.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
            The image to convert to the `PIL.Image` format. PIL images are returned unchanged. Torch tensors are
            detached and copied to host memory before conversion.
        do_rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). When
            unset, the decision is made by `_rescale_for_pil_conversion`.
        image_mode (`str`, *optional*):
            The mode to use for the PIL image. If unset, will use the default mode for the input image type.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `PIL.Image.Image`: The converted image.

    Raises:
        ValueError: If `image` is neither a PIL image, a NumPy array nor a torch tensor.
    """
    requires_backends(to_pil_image, ["vision"])

    if isinstance(image, PIL.Image.Image):
        return image

    # Convert all tensors to numpy arrays before converting to PIL image
    if is_torch_tensor(image):
        # `Tensor.numpy()` raises for tensors that require grad or that live on an accelerator,
        # so drop the autograd graph and move the data to the CPU first.
        image = image.detach().cpu().numpy()
    elif not isinstance(image, np.ndarray):
        raise ValueError(f"Input image type not supported: {type(image)}")

    # If the channel has been moved to first dim, we put it back at the end.
    image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)

    # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
    if image.shape[-1] == 1:
        image = np.squeeze(image, axis=-1)

    # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
    if do_rescale is None:
        do_rescale = _rescale_for_pil_conversion(image)

    if do_rescale:
        image = rescale(image, 255)

    image = image.astype(np.uint8)
    return PIL.Image.fromarray(image, mode=image_mode)
204
+
205
+
206
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
    """
    Compute the `(height, width)` obtained by resizing the shorter edge of an image to `size` while keeping the
    aspect ratio, optionally capping the longer edge at `max_size`.

    Args:
        image_size (`tuple[int, int]`):
            The input image size as `(height, width)`.
        size (`int`):
            The desired length of the shorter edge.
        max_size (`int`, *optional*):
            The maximum allowed length of the longer edge.
    """
    height, width = image_size

    # Un-rounded short-edge length, only set when `max_size` forces a smaller size than requested.
    raw_size = None
    if max_size is not None:
        shorter = float(min(height, width))
        longer = float(max(height, width))
        if longer / shorter * size > max_size:
            raw_size = max_size * shorter / longer
            size = int(round(raw_size))

    short_edge_already_matches = (height <= width and height == size) or (width <= height and width == size)
    if short_edge_already_matches:
        return (height, width)

    # The long edge is derived from the un-rounded value when one exists, to limit rounding drift.
    basis = size if raw_size is None else raw_size
    if width < height:
        return (int(basis * height / width), size)
    return (size, int(basis * width / height))
243
+
244
+
245
+ # Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
246
def get_resize_output_image_size(
    input_image: np.ndarray,
    size: int | tuple[int, int] | list[int] | tuple[int, ...],
    default_to_square: bool = True,
    max_size: int | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> tuple:
    """
    Work out the `(height, width)` an image will have after resizing, given the image and the requested size.

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        size (`int` or `tuple[int, int]` or `list[int]` or `tuple[int]`):
            The requested size. A two-element sequence `(h, w)` is returned as is. A one-element sequence is
            treated like the int it contains.

            If `size` is an int and `default_to_square` is `True`, the output is `(size, size)`. If `size` is an
            int and `default_to_square` is `False`, the smaller edge of the image is matched to this number, i.e.
            if height > width the image is rescaled to (size * height / width, size).
        default_to_square (`bool`, *optional*, defaults to `True`):
            How to interpret a single-int `size`. When `True` it becomes the square (`size`, `size`). When `False`
            the behaviour replicates
            [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
            with support for resizing only the smallest edge and an optional `max_size`.
        max_size (`int`, *optional*):
            Upper bound for the longer edge of the resized image: if the longer edge would exceed `max_size` after
            resizing according to `size`, the image is resized again so that the longer edge equals `max_size`.
            As a result the smaller edge may end up shorter than `size`. Only used if `default_to_square` is
            `False`.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `tuple`: The target (height, width) dimension of the output image after resizing.
    """
    if isinstance(size, (tuple, list)):
        num_entries = len(size)
        if num_entries == 2:
            return tuple(size)
        if num_entries != 1:
            raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
        # A single-element sequence is handled exactly like a bare int.
        size = size[0]

    if default_to_square:
        return (size, size)

    height, width = get_image_size(input_image, input_data_format)
    width_is_short = width <= height
    short, long = (width, height) if width_is_short else (height, width)

    new_short = size
    new_long = int(size * long / short)

    if max_size is not None:
        if max_size <= size:
            raise ValueError(
                f"max_size = {max_size} must be strictly greater than the requested "
                f"size for the smaller edge size = {size}"
            )
        if new_long > max_size:
            # Shrink both edges by the same ratio so that the long edge lands exactly on `max_size`.
            new_short = int(max_size * new_short / new_long)
            new_long = max_size

    return (new_long, new_short) if width_is_short else (new_short, new_long)
311
+
312
+
313
def resize(
    image: np.ndarray,
    size: tuple[int, int],
    resample: Optional["PILImageResampling"] = None,
    reducing_gap: int | None = None,
    data_format: ChannelDimension | None = None,
    return_numpy: bool = True,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Resize `image` to the `(height, width)` given by `size` using the PIL library.

    Args:
        image (`np.ndarray`):
            The image to resize.
        size (`tuple[int, int]`):
            The target `(height, width)`.
        resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            The filter to use for resampling.
        reducing_gap (`int`, *optional*):
            Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result
            to the fair resampling. See corresponding Pillow documentation for more details.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, the format of the input is used.
        return_numpy (`bool`, *optional*, defaults to `True`):
            Whether or not to return the resized image as a numpy array. If `False` a `PIL.Image.Image` object is
            returned.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The resized image.
    """
    requires_backends(resize, ["vision"])

    if resample is None:
        resample = PILImageResampling.BILINEAR

    if len(size) != 2:
        raise ValueError("size must have 2 elements")

    # Unless told otherwise the output keeps the layout of the input. PIL always hands back channels-last data,
    # so the input layout has to be determined before the conversion.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    if data_format is None:
        data_format = input_data_format

    # Resizing goes through Pillow to stay backwards compatible with the previous image feature extractors.
    do_rescale = False
    if not isinstance(image, PIL.Image.Image):
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)

    target_height, target_width = size
    # PIL expects sizes as (width, height).
    resized = image.resize((target_width, target_height), resample=resample, reducing_gap=reducing_gap)
    if not return_numpy:
        return resized

    resized = np.array(resized)
    # A single-channel axis is dropped by the PIL round trip; restore it.
    if resized.ndim == 2:
        resized = np.expand_dims(resized, axis=-1)
    # Data coming back from PIL is always channels-last.
    resized = to_channel_dimension_format(resized, data_format, input_channel_dim=ChannelDimension.LAST)
    # Undo the [0, 255] scaling that was applied before the PIL conversion, if any.
    if do_rescale:
        resized = rescale(resized, 1 / 255)
    return resized
382
+
383
+
384
def normalize(
    image: np.ndarray,
    mean: float | Collection[float],
    std: float | Collection[float],
    data_format: ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Normalize `image` with the given per-channel mean and standard deviation:

    image = (image - mean) / std

    Args:
        image (`np.ndarray`):
            The image to normalize.
        mean (`float` or `Collection[float]`):
            The mean to subtract. A scalar is applied to every channel.
        std (`float` or `Collection[float]`):
            The standard deviation to divide by. A scalar is applied to every channel.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, the format of the input is kept.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError("image must be a numpy array")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    num_channels = image.shape[get_channel_dimension_axis(image, input_data_format=input_data_format)]

    # Non-float inputs (e.g. uint8) are cast to float32 so the subtraction cannot go wrong;
    # float inputs keep their dtype so that float16 is not upcast.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    if not isinstance(mean, Collection):
        mean = [mean] * num_channels
    elif len(mean) != num_channels:
        raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
    mean = np.array(mean, dtype=image.dtype)

    if not isinstance(std, Collection):
        std = [std] * num_channels
    elif len(std) != num_channels:
        raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
    std = np.array(std, dtype=image.dtype)

    if input_data_format == ChannelDimension.LAST:
        normalized = (image - mean) / std
    else:
        # Transposing puts the channel axis last so the statistics broadcast, then the layout is restored.
        normalized = ((image.T - mean) / std).T

    if data_format is None:
        return normalized
    return to_channel_dimension_format(normalized, data_format, input_data_format)
443
+
444
+
445
def center_crop(
    image: np.ndarray,
    size: tuple[int, int],
    data_format: str | ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Crop `image` to `size` around its center. If the image is smaller than `size` along an axis it is zero-padded
    first, so the result always has size `size`.

    Args:
        image (`np.ndarray`):
            The image to crop.
        size (`tuple[int, int]`):
            The target `(height, width)` of the cropped image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The cropped image.
    """
    requires_backends(center_crop, ["vision"])

    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if not isinstance(size, Iterable) or len(size) != 2:
        raise ValueError("size must have 2 elements representing the height and width of the output image")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    output_data_format = input_data_format if data_format is None else data_format

    # All the slicing below is done in (C, H, W) layout; the requested layout is restored on return.
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)

    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
    crop_height, crop_width = (int(edge) for edge in size)

    # Derive bottom/right from top/left: with odd sizes, (image_shape + size) // 2 would not give the proper result.
    top = (orig_height - crop_height) // 2
    left = (orig_width - crop_width) // 2
    bottom = top + crop_height
    right = left + crop_width

    fits_inside = top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width
    if fits_inside:
        cropped = image[..., top:bottom, left:right]
        return to_channel_dimension_format(cropped, output_data_format, ChannelDimension.FIRST)

    # The image is too small along at least one axis: paste it onto a zero canvas that is large enough.
    padded_height = max(crop_height, orig_height)
    padded_width = max(crop_width, orig_width)
    canvas = np.zeros_like(image, shape=image.shape[:-2] + (padded_height, padded_width))

    row_offset = ceil((padded_height - orig_height) / 2)
    col_offset = ceil((padded_width - orig_width) / 2)
    canvas[..., row_offset : row_offset + orig_height, col_offset : col_offset + orig_width] = image

    # Express the crop window in canvas coordinates and clip it to the canvas.
    top, bottom = top + row_offset, bottom + row_offset
    left, right = left + col_offset, right + col_offset
    window = canvas[..., max(0, top) : min(padded_height, bottom), max(0, left) : min(padded_width, right)]
    return to_channel_dimension_format(window, output_data_format, ChannelDimension.FIRST)
527
+
528
+
529
+ def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
530
+ center_x, center_y, width, height = bboxes_center.unbind(-1)
531
+ bbox_corners = torch.stack(
532
+ # top left x, top left y, bottom right x, bottom right y
533
+ [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
534
+ dim=-1,
535
+ )
536
+ return bbox_corners
537
+
538
+
539
+ def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
540
+ center_x, center_y, width, height = bboxes_center.T
541
+ bboxes_corners = np.stack(
542
+ # top left x, top left y, bottom right x, bottom right y
543
+ [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
544
+ axis=-1,
545
+ )
546
+ return bboxes_corners
547
+
548
+
549
+ # 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
550
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
    """
    Convert bounding boxes from center format to corners format.

    center format: the coordinates of the center of the box followed by its width and height
        (center_x, center_y, width, height)
    corners format: the coordinates of the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)

    Raises:
        ValueError: If the input is neither a torch tensor nor a NumPy array.
    """
    # This runs inside model forward passes, so torch inputs stay in torch instead of going through numpy.
    if is_torch_tensor(bboxes_center):
        return _center_to_corners_format_torch(bboxes_center)
    if isinstance(bboxes_center, np.ndarray):
        return _center_to_corners_format_numpy(bboxes_center)

    raise ValueError(f"Unsupported input type {type(bboxes_center)}")
566
+
567
+
568
+ def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
569
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
570
+ b = [
571
+ (top_left_x + bottom_right_x) / 2, # center x
572
+ (top_left_y + bottom_right_y) / 2, # center y
573
+ (bottom_right_x - top_left_x), # width
574
+ (bottom_right_y - top_left_y), # height
575
+ ]
576
+ return torch.stack(b, dim=-1)
577
+
578
+
579
+ def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
580
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
581
+ bboxes_center = np.stack(
582
+ [
583
+ (top_left_x + bottom_right_x) / 2, # center x
584
+ (top_left_y + bottom_right_y) / 2, # center y
585
+ (bottom_right_x - top_left_x), # width
586
+ (bottom_right_y - top_left_y), # height
587
+ ],
588
+ axis=-1,
589
+ )
590
+ return bboxes_center
591
+
592
+
593
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Convert bounding boxes from corners format to center format.

    corners format: the coordinates of the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    center format: the coordinates of the center of the box followed by its width and height
        (center_x, center_y, width, height)

    Raises:
        ValueError: If the input is neither a torch tensor nor a NumPy array.
    """
    # Mirrors `center_to_corners_format`: both input types are supported here as well.
    if is_torch_tensor(bboxes_corners):
        return _corners_to_center_format_torch(bboxes_corners)
    if isinstance(bboxes_corners, np.ndarray):
        return _corners_to_center_format_numpy(bboxes_corners)

    raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
609
+
610
+
611
+ def safe_squeeze(
612
+ tensor: Union[np.ndarray, "torch.Tensor"], axis: int | None = None
613
+ ) -> Union[np.ndarray, "torch.Tensor"]:
614
+ """
615
+ Squeezes a tensor, but only if the axis specified has dim 1.
616
+ """
617
+ if axis is None:
618
+ return tensor.squeeze()
619
+
620
+ try:
621
+ return tensor.squeeze(axis=axis)
622
+ except ValueError:
623
+ return tensor
624
+
625
+
626
+ # 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
627
+ # Copyright (c) 2018, Alexander Kirillov
628
+ # All rights reserved.
629
def rgb_to_id(color):
    """
    Convert an RGB color to a unique integer ID (`R + 256 * G + 256 * 256 * B`).

    A 3-dimensional `np.ndarray` is converted element-wise over its first two dimensions and an array of IDs is
    returned; any other input is indexed as a single `(R, G, B)` triple and a Python `int` is returned.
    """
    is_color_map = isinstance(color, np.ndarray) and len(color.shape) == 3
    if not is_color_map:
        return int(color[0] + 256 * color[1] + 256 * 256 * color[2])

    if color.dtype == np.uint8:
        # Widen first so the weighted sum below does not wrap around in 8 bits
        color = color.astype(np.int32)
    red, green, blue = color[:, :, 0], color[:, :, 1], color[:, :, 2]
    return red + 256 * green + 256 * 256 * blue
638
+
639
+
640
def id_to_rgb(id_map):
    """
    Convert a unique ID (or an array of IDs) back to RGB.

    For a `np.ndarray` input, returns a `uint8` array with a trailing dimension of size 3; the input array is not
    modified. For any other input, returns a list of the three base-256 digits, least significant (R) first.
    """
    if not isinstance(id_map, np.ndarray):
        channels = []
        remaining = id_map
        for _ in range(3):
            channels.append(remaining % 256)
            remaining //= 256
        return channels

    # Work on a copy because the IDs are consumed in place, one base-256 digit per channel
    remaining_ids = id_map.copy()
    rgb_map = np.zeros((*id_map.shape, 3), dtype=np.uint8)
    for channel in range(3):
        rgb_map[..., channel] = remaining_ids % 256
        remaining_ids //= 256
    return rgb_map
657
+
658
+
659
class PaddingMode(ExplicitEnum):
    """
    Enum class for the different padding modes to use when padding images.

    The comments below give the `np.pad` mode that `pad` selects for each member.
    """

    CONSTANT = "constant"  # `np.pad(..., mode="constant")`, filled with `constant_values`
    REFLECT = "reflect"  # `np.pad(..., mode="reflect")`
    REPLICATE = "replicate"  # `np.pad(..., mode="edge")`
    SYMMETRIC = "symmetric"  # `np.pad(..., mode="symmetric")`
668
+
669
+
670
def pad(
    image: np.ndarray,
    padding: int | tuple[int, int] | Iterable[tuple[int, int]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: float | Iterable[float] = 0.0,
    data_format: str | ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Pads the `image` with the specified (height, width) `padding` and `mode`.

    Args:
        image (`np.ndarray`):
            The image to pad.
        padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of these formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` or `(before, after)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`. Accepts the same layouts as `padding`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.

    Raises:
        ValueError: If `padding` / `constant_values` use an unsupported layout, or if `mode` is not a known
            `PaddingMode`.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            # `(pad,)` repeats one scalar everywhere; `((before, after),)` repeats one pair for both axes.
            # The pair case must not be wrapped again, otherwise np.pad receives a wrongly nested structure.
            pair = values[0] if isinstance(values[0], tuple) else (values[0], values[0])
            values = (pair, pair)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], (int, float)):
            # `(before, after)`; floats are accepted so that `constant_values=(0.5, 1.0)` works
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            pass
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add 0 for channel dimension
        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))

        # Add additional padding if there's a batch dimension
        values = ((0, 0), *values) if image.ndim == 4 else values
        return values

    padding = _expand_for_data_format(padding)

    if mode == PaddingMode.CONSTANT:
        constant_values = _expand_for_data_format(constant_values)
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    elif mode == PaddingMode.REPLICATE:
        image = np.pad(image, padding, mode="edge")
    elif mode == PaddingMode.SYMMETRIC:
        image = np.pad(image, padding, mode="symmetric")
    else:
        raise ValueError(f"Invalid padding mode: {mode}")

    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
754
+
755
+
756
+ # TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
757
def convert_to_rgb(image: ImageInput) -> ImageInput:
    """
    Converts an image to RGB format.

    Only `PIL.Image.Image` inputs whose mode is not already `"RGB"` are converted; every other input is returned
    as is.

    Args:
        image (Image):
            The image to convert.
    """
    requires_backends(convert_to_rgb, ["vision"])

    needs_conversion = isinstance(image, PIL.Image.Image) and image.mode != "RGB"
    return image.convert("RGB") if needs_conversion else image
775
+
776
+
777
def flip_channel_order(
    image: np.ndarray,
    data_format: ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Flips the channel order of the image.

    If the image is in RGB format, it will be converted to BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Raises:
        ValueError: If the (given or inferred) input format is neither channels-first nor channels-last.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    if input_data_format == ChannelDimension.LAST:
        # Channels live on the last axis: reverse it
        flipped = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        # Channels live on the first axis: reverse it
        flipped = image[::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is None:
        return flipped
    return to_channel_dimension_format(flipped, data_format, input_channel_dim=input_data_format)
813
+
814
+
815
def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
    """
    Split a batch of images into a grid of `num_tiles_height x num_tiles_width` tiles.

    `images` is unpacked as (batch_size, num_channels, height, width). The result has shape
    (batch_size, num_tiles_height * num_tiles_width, num_channels, height // num_tiles_height,
    width // num_tiles_width), with tiles ordered row by row.
    """
    batch_size, num_channels, height, width = images.size()
    tile_height = height // num_tiles_height
    tile_width = width // num_tiles_width

    # Expose the tile grid as separate axes: (batch, channels, tile_row, tile_h, tile_col, tile_w)
    grid = images.view(batch_size, num_channels, num_tiles_height, tile_height, num_tiles_width, tile_width)
    # Bring the two grid axes next to the batch axis: (batch, tile_row, tile_col, channels, tile_h, tile_w)
    grid = grid.permute(0, 2, 4, 1, 3, 5).contiguous()
    # Merge the grid axes into a single tile axis
    return grid.view(batch_size, num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
837
+
838
+
839
def divide_to_patches(
    image: Union[np.ndarray, "torch.Tensor"], patch_size: int | tuple[int, int]
) -> list[Union[np.ndarray, "torch.Tensor"]]:
    """
    Divides an image into patches of a specified size.

    The image size is read with `channel_dim=ChannelDimension.FIRST` and the patches are sliced from the last
    two dimensions, row by row. Patches on the bottom/right border are smaller when the image size is not a
    multiple of the patch size.

    Args:
        image (`np.array | "torch.Tensor"`):
            The input image.
        patch_size (`int` or `tuple[int, int]`):
            The size of each patch. If an int, patches are square. If a tuple,
            it is interpreted as `(patch_height, patch_width)`.
    Returns:
        list: A list of `np.array | "torch.Tensor"` representing the patches.
    """
    if isinstance(patch_size, int):
        patch_h = patch_w = patch_size
    else:
        patch_h, patch_w = patch_size

    height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    return [
        image[..., top : top + patch_h, left : left + patch_w]
        for top in range(0, height, patch_h)
        for left in range(0, width, patch_w)
    ]
863
+
864
+
865
+ def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
866
+ """
867
+ Helper function to flatten a single level of nested image and batch structures and group by shape.
868
+ Args:
869
+ nested_images (list):
870
+ A list of images or a single tensor
871
+ paired_inputs (Any, *optional*):
872
+ Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
873
+ `is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
874
+ same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
875
+ they do not need to be tensors.
876
+ is_nested (bool, *optional*, defaults to False):
877
+ Whether the images are nested.
878
+ Returns:
879
+ tuple[dict, ...]:
880
+ - A dictionary with shape as key and list of images with that shape as value
881
+ - A dictionary with shape as key and list of paired values with that shape as value
882
+ - A dictionary mapping original indices to (shape, index) tuples
883
+ - A dictionary mapping original indices to (shape, index) tuples for each paired input
884
+ """
885
+ grouped_images = defaultdict(list)
886
+ grouped_images_index = {}
887
+ paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
888
+
889
+ # Normalize inputs to consistent nested structure
890
+ normalized_images = [nested_images] if not is_nested else nested_images
891
+ normalized_paired = []
892
+ for paired_input in paired_inputs:
893
+ normalized_paired.append([paired_input] if not is_nested else paired_input)
894
+
895
+ # Process each image and group by shape
896
+ for i, (sublist, *paired_sublists) in enumerate(zip(normalized_images, *normalized_paired)):
897
+ for j, (image, *paired_values) in enumerate(zip(sublist, *paired_sublists)):
898
+ key = (i, j) if is_nested else j
899
+ shape = image.shape[1:]
900
+
901
+ # Add to grouped structures
902
+ grouped_images[shape].append(image)
903
+ for paired_index, paired_value in enumerate(paired_values):
904
+ paired_grouped_values[paired_index][shape].append(paired_value)
905
+ grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
906
+
907
+ # Store structure size for nested inputs to handle empty sublists during reconstruction
908
+ if is_nested:
909
+ grouped_images_index["_num_sublists"] = len(normalized_images)
910
+
911
+ return grouped_images, *paired_grouped_values, grouped_images_index
912
+
913
+
914
+ def _reconstruct_nested_structure(indices, processed_images):
915
+ """Helper function to reconstruct a single level nested structure."""
916
+ # Get the number of sublists (handles empty sublists like in [[], [image]])
917
+ num_sublists = indices.pop("_num_sublists", None)
918
+
919
+ # Group indices by outer index
920
+ nested_indices = defaultdict(list)
921
+ for i, j in indices:
922
+ nested_indices[i].append(j)
923
+
924
+ # Determine the number of outer sublists
925
+ if num_sublists is not None:
926
+ max_outer_idx = num_sublists - 1
927
+ elif nested_indices:
928
+ max_outer_idx = max(nested_indices.keys())
929
+ else:
930
+ return []
931
+
932
+ # Create the result structure
933
+ result = []
934
+ for i in range(max_outer_idx + 1):
935
+ if i not in nested_indices:
936
+ result.append([])
937
+ else:
938
+ inner_max_idx = max(nested_indices[i])
939
+ inner_list = [None] * (inner_max_idx + 1)
940
+ for j in nested_indices[i]:
941
+ shape, idx = indices[(i, j)]
942
+ inner_list[j] = processed_images[shape][idx]
943
+ result.append(inner_list)
944
+
945
+ return result
946
+
947
+
948
+ def _iterate_items(items, is_nested: bool):
949
+ """
950
+ Helper function to iterate over items yielding (key, item) pairs.
951
+
952
+ For nested structures, yields ((row_index, col_index), item).
953
+ For flat structures, yields (index, item).
954
+ """
955
+ if is_nested:
956
+ for i, row in enumerate(items):
957
+ for j, item in enumerate(row):
958
+ yield (i, j), item
959
+ else:
960
+ for i, item in enumerate(items):
961
+ yield i, item
962
+
963
+
964
+ def _get_device_from_images(images, is_nested: bool) -> "torch.device":
965
+ """
966
+ Get the device from the first non-empty element in a (potentially nested) list of images.
967
+
968
+ Handles cases like `images = [[], [image]]` where the first sublist may be empty.
969
+ """
970
+ if is_nested:
971
+ for row in images:
972
+ if isinstance(row, torch.Tensor):
973
+ return row.device
974
+ if isinstance(row, list) and len(row) > 0:
975
+ return row[0].device
976
+ return images[0].device
977
+
978
+
979
def group_images_by_shape(
    images: Union[list["torch.Tensor"], "torch.Tensor"],
    *paired_inputs,
    disable_grouping: bool | None,
    is_nested: bool = False,
) -> tuple[dict, ...]:
    """
    Groups images by shape.
    Returns a dictionary with the shape as key and a list of images with that shape as value,
    and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.

    The function supports both flat lists of tensors and nested structures.
    The input must be either all flat or all nested, not a mix of both.

    Args:
        images (Union[list["torch.Tensor"], "torch.Tensor"]):
            A list of images or a single tensor
        paired_inputs (Any, *optional*):
            Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
            `is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
            same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
            they do not need to be tensors.
        disable_grouping (bool):
            Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise.
            This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
        is_nested (bool, *optional*, defaults to False):
            Whether the images are nested.

    Returns:
        tuple[dict, ...]:
            - A dictionary with shape as key and list/batch of images with that shape as value. When grouping is
              disabled, the keys are the original indices and every value is a batch of one image.
            - Zero or more dictionaries (one per argument in `*paired_inputs`) grouped consistently with `images`; these carry
              the corresponding per-item values and are not stacked
            - A dictionary mapping original indices to (shape, index) tuples
    """
    # If disable grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
    if disable_grouping is None:
        device = _get_device_from_images(images, is_nested)
        # A `torch.device` never compares equal to a plain string, so `device == "cpu"` is always False and
        # the CPU default would never apply. Compare the device type instead.
        disable_grouping = torch.device(device).type == "cpu"

    if disable_grouping:
        # Every image becomes its own group, keyed by its original index
        grouped_images_index = {key: (key, 0) for key, _ in _iterate_items(images, is_nested)}
        if is_nested:
            grouped_images_index["_num_sublists"] = len(images)

        return (
            {key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
            *[
                {
                    # Paired values do not have to be tensors: wrap non-tensors in a one-element list so they
                    # are still indexable with the `0` stored in `grouped_images_index`
                    key: item.unsqueeze(0) if isinstance(item, torch.Tensor) else [item]
                    for key, item in _iterate_items(paired_list, is_nested)
                }
                for paired_list in paired_inputs
            ],
            grouped_images_index,
        )

    # Handle single level nested structure
    grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
        images, *paired_inputs, is_nested=is_nested
    )

    # Stack images with the same shape
    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}

    return grouped_images, *paired_grouped_values, grouped_images_index
1042
+
1043
+
1044
+ def reorder_images(
1045
+ processed_images: dict[tuple[int, int], "torch.Tensor"],
1046
+ grouped_images_index: dict[int | tuple[int, int], tuple[tuple[int, int], int]],
1047
+ is_nested: bool = False,
1048
+ ) -> Union[list["torch.Tensor"], "torch.Tensor"]:
1049
+ """
1050
+ Reconstructs images in the original order, preserving the original structure (nested or not).
1051
+ The input structure is either all flat or all nested.
1052
+
1053
+ Args:
1054
+ processed_images (dict[tuple[int, int], "torch.Tensor"]):
1055
+ Dictionary mapping shapes to batched processed images.
1056
+ grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
1057
+ Dictionary mapping original indices to (shape, index) tuples.
1058
+ is_nested (bool, *optional*, defaults to False):
1059
+ Whether the images are nested. Cannot be inferred from the input, as some processing functions outputs nested images.
1060
+ even with non nested images,e.g functions splitting images into patches. We thus can't deduce is_nested from the input.
1061
+
1062
+
1063
+ Returns:
1064
+ Union[list["torch.Tensor"], "torch.Tensor"]:
1065
+ Images in the original structure.
1066
+ """
1067
+ if not is_nested:
1068
+ return [
1069
+ processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
1070
+ for i in range(len(grouped_images_index))
1071
+ ]
1072
+
1073
+ return _reconstruct_nested_structure(grouped_images_index, processed_images)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Static type checkers only: import every public name of the submodules eagerly
    from .configuration_gemma3 import *
    from .image_processing_gemma3 import *
    from .image_processing_pil_gemma3 import *
    from .modeling_gemma3 import *
    from .processing_gemma3 import *
else:
    import sys

    # At runtime this module object is replaced in `sys.modules` by a `_LazyModule` built from the import
    # structure derived from this file (presumably so submodules are only imported on first attribute access).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/configuration_gemma3.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_gemma3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from typing import Any
22
+
23
+ from huggingface_hub.dataclasses import strict
24
+
25
+ from ...configuration_utils import PreTrainedConfig
26
+ from ...utils import auto_docstring, logging
27
+ from ..siglip import SiglipVisionConfig
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
@auto_docstring(checkpoint="google/gemma-3-4b-it")
@strict
class Gemma3TextConfig(PreTrainedConfig):
    r"""
    query_pre_attn_scalar (`float`, *optional*, defaults to 256):
        scaling factor used on the attention scores
    final_logit_softcapping (`float`, *optional*):
        Scaling factor when applying tanh softcapping on the logits.
    attn_logit_softcapping (`float`, *optional*):
        Scaling factor when applying tanh softcapping on the attention scores.
    use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
        If True, the model will attend to all text tokens instead of using a causal mask. This does not change
        behavior for vision tokens.

    ```python
    >>> from transformers import Gemma3TextModel, Gemma3TextConfig
    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
    >>> configuration = Gemma3TextConfig()
    >>> # Initializing a model from the gemma3_text-7b style configuration
    >>> model = Gemma3TextModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "gemma3_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Per-module sharding style, keyed by a module-name pattern (presumably consumed by tensor parallelism)
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Per-stage (input names, output names) pairs (presumably consumed by pipeline parallelism)
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    vocab_size: int = 262_208
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 256
    hidden_activation: str = "gelu_pytorch_tanh"
    max_position_embeddings: int = 131_072
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = 0
    eos_token_id: int | list[int] | None = 1
    bos_token_id: int | None = 2
    tie_word_embeddings: bool = True
    rope_parameters: dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float | None = 0.0
    query_pre_attn_scalar: int = 256
    sliding_window: int | None = 4096
    layer_types: list[str] | None = None
    final_logit_softcapping: float | None = None
    attn_logit_softcapping: float | None = None
    use_bidirectional_attention: bool | None = False
    # Default `rope_theta` values: "global" is used for full attention, "local" for sliding attention
    default_theta = {"global": 1_000_000.0, "local": 10_000.0}

    def __post_init__(self, **kwargs):
        """Derive `sliding_window` and `layer_types` defaults, then defer to the parent `__post_init__`."""
        if self.use_bidirectional_attention:
            # NOTE(review): raises TypeError if `sliding_window` is None while bidirectional attention is
            # enabled; confirm whether that combination can occur.
            self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds

        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

        if self.layer_types is None:
            # Every `_sliding_window_pattern`-th layer uses full attention, all other layers use sliding attention
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]

        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """Part of `@strict`-powered validation. Validates the architecture of the config."""
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})."
            )

    def convert_rope_params_to_dict(self, **kwargs):
        """
        Normalize `rope_parameters` into a dict with a `"full_attention"` and a `"sliding_attention"` entry.

        Consumes the `rope_scaling`, `rope_theta` and `rope_local_base_freq` keys from `kwargs` when they are
        used and returns the remaining `kwargs`.
        """
        rope_scaling = kwargs.pop("rope_scaling", None)

        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
        # as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
        default_rope_params = {
            "sliding_attention": {"rope_type": "default"},
            "full_attention": {"rope_type": "default"},
        }
        self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else default_rope_params
        if rope_scaling is not None:
            # NOTE(review): if `rope_parameters` was supplied without a "full_attention" entry, this lookup
            # raises KeyError because the default for that entry is only filled in below.
            self.rope_parameters["full_attention"].update(rope_scaling)

        # Set default values if not present
        if self.rope_parameters.get("full_attention") is None:
            self.rope_parameters["full_attention"] = {"rope_type": "default"}
        # `rope_theta` is only popped from kwargs (or defaulted) as a fallback; an existing value is kept
        self.rope_parameters["full_attention"].setdefault(
            "rope_theta", kwargs.pop("rope_theta", self.default_theta["global"])
        )
        if self.rope_parameters.get("sliding_attention") is None:
            self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
        self.rope_parameters["sliding_attention"].setdefault(
            "rope_theta", kwargs.pop("rope_local_base_freq", self.default_theta["local"])
        )

        # Standardize and validate the correctness of rotary position embeddings parameters
        self.standardize_rope_params()
        return kwargs
155
+
156
+
157
@auto_docstring(checkpoint="google/gemma-3-4b-it")
@strict
class Gemma3Config(PreTrainedConfig):
    r"""
    mm_tokens_per_image (`int`, *optional*, defaults to 256):
        The number of tokens per image embedding.
    boi_token_index (`int`, *optional*, defaults to 255999):
        The begin-of-image token index to wrap the image prompt.
    eoi_token_index (`int`, *optional*, defaults to 256000):
        The end-of-image token index to wrap the image prompt.

    Example:

    ```python
    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Gemma3 gemma-3-4b style configuration
    >>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)

    >>> # Initializing a model from the gemma-3-4b style configuration
    >>> model = Gemma3ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma3"
    # Alias names (`*_token_id`) mapped to the attribute names actually stored on the config (`*_token_index`)
    attribute_map = {
        "image_token_id": "image_token_index",
        "boi_token_id": "boi_token_index",
        "eoi_token_id": "eoi_token_index",
    }
    # Config classes used to build the nested configs in `__post_init__`
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    text_config: Gemma3TextConfig | dict[str, Any] | None = None
    vision_config: SiglipVisionConfig | dict[str, Any] | None = None
    mm_tokens_per_image: int | None = 256
    boi_token_index: int | None = 255_999
    eoi_token_index: int | None = 256_000
    image_token_index: int | None = 262_144
    initializer_range: float | None = 0.02
    tie_word_embeddings: bool | None = True

    def __post_init__(self, **kwargs):
        """Turn `text_config` / `vision_config` into config objects (defaults when None, built from a dict otherwise)."""
        if self.text_config is None:
            self.text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig text config.")
        elif isinstance(self.text_config, dict):
            self.text_config = Gemma3TextConfig(**self.text_config)

        if isinstance(self.vision_config, dict):
            self.vision_config = SiglipVisionConfig(**self.vision_config)
        elif self.vision_config is None:
            self.vision_config = SiglipVisionConfig()
            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")

        super().__post_init__(**kwargs)
223
+
224
+
225
# Public names exported by this module
__all__ = ["Gemma3Config", "Gemma3TextConfig"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_gemma3.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Image processor class for Gemma3."""
15
+
16
+ import itertools
17
+ import math
18
+
19
+ import torch
20
+ from torchvision.transforms.v2 import functional as tvF
21
+
22
+ from ...image_processing_backends import TorchvisionBackend
23
+ from ...image_processing_utils import BatchFeature
24
+ from ...image_transforms import group_images_by_shape, reorder_images
25
+ from ...image_utils import (
26
+ IMAGENET_STANDARD_MEAN,
27
+ IMAGENET_STANDARD_STD,
28
+ ImageInput,
29
+ PILImageResampling,
30
+ SizeDict,
31
+ )
32
+ from ...processing_utils import ImagesKwargs, Unpack
33
+ from ...utils import (
34
+ TensorType,
35
+ auto_docstring,
36
+ )
37
+
38
+
39
class Gemma3ImageProcessorKwargs(ImagesKwargs, total=False):
    """
    do_pan_and_scan (`bool`, *optional*):
        Whether to apply `pan_and_scan` to images.
    pan_and_scan_min_crop_size (`int`, *optional*):
        Minimum size of each crop in pan and scan.
    pan_and_scan_max_num_crops (`int`, *optional*):
        Maximum number of crops per image in pan and scan.
    pan_and_scan_min_ratio_to_activate (`float`, *optional*):
        Minimum aspect ratio to activate pan and scan.
    """

    # Gemma3-specific keys added on top of the inherited `ImagesKwargs` keys.
    # NOTE(review): `total=False` suggests `ImagesKwargs` is a TypedDict, making
    # every key below optional — confirm against `processing_utils`.
    do_pan_and_scan: bool
    pan_and_scan_min_crop_size: int
    pan_and_scan_max_num_crops: int
    pan_and_scan_min_ratio_to_activate: float
55
+
56
+
57
@auto_docstring
class Gemma3ImageProcessor(TorchvisionBackend):
    # Class-level processing defaults; the pan-and-scan options default to None.
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 224, "width": 224}
    default_to_square = True
    do_convert_rgb = True
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_pan_and_scan = None
    pan_and_scan_min_crop_size = None
    pan_and_scan_max_num_crops = None
    pan_and_scan_min_ratio_to_activate = None
    valid_kwargs = Gemma3ImageProcessorKwargs
    model_input_names = ["pixel_values", "num_crops"]

    def __init__(self, **kwargs: Unpack[Gemma3ImageProcessorKwargs]):
        super().__init__(**kwargs)

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Gemma3ImageProcessorKwargs]) -> BatchFeature:
        # Pure pass-through to the backend; overridden only to narrow the kwargs type.
        return super().preprocess(images, **kwargs)

    def pan_and_scan_batched(
        self,
        images: "torch.Tensor",
        pan_and_scan_min_crop_size: int,
        pan_and_scan_max_num_crops: int,
        pan_and_scan_min_ratio_to_activate: float,
    ):
        """
        Split elongated images into crops along their longer side.

        Args:
            images (`torch.Tensor`):
                Tensor whose last two dimensions are read as `(height, width)`; all leading
                dimensions are kept unchanged in every crop.
            pan_and_scan_min_crop_size (`int`):
                Minimum size of each crop in pan and scan.
            pan_and_scan_max_num_crops (`int`):
                Maximum number of crops per image in pan and scan.
            pan_and_scan_min_ratio_to_activate (`float`):
                Minimum long-side / short-side ratio required to activate pan and scan.

        Returns:
            A list of crops sliced from `images` (rows first, then columns), or an empty list
            when the aspect ratio is below the activation threshold or the resulting crop
            would be smaller than `pan_and_scan_min_crop_size`.
        """
        height, width = images.shape[-2:]

        # Square images count as landscape, so cropping then happens along the width.
        is_landscape = width >= height
        long_side, short_side = (width, height) if is_landscape else (height, width)

        # Only apply PaS if the image is sufficiently exaggerated.
        if long_side / short_side < pan_and_scan_min_ratio_to_activate:
            return []

        # Crop count close to the aspect ratio (half-round-up), limited so that each crop
        # stays at least `pan_and_scan_min_crop_size` long.
        crops_along_long = int(math.floor(long_side / short_side + 0.5))
        crops_along_long = min(int(math.floor(long_side / pan_and_scan_min_crop_size)), crops_along_long)

        # Clamp into [2, pan_and_scan_max_num_crops]; the upper bound is applied last.
        crops_along_long = min(pan_and_scan_max_num_crops, max(2, crops_along_long))

        num_crops_w, num_crops_h = (crops_along_long, 1) if is_landscape else (1, crops_along_long)
        crop_size_w = int(math.ceil(width / num_crops_w))
        crop_size_h = int(math.ceil(height / num_crops_h))

        # Don't apply PaS if crop size is too small.
        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return []

        return [
            images[
                ...,
                row * crop_size_h : row * crop_size_h + crop_size_h,
                col * crop_size_w : col * crop_size_w + crop_size_w,
            ]
            for row in range(num_crops_h)
            for col in range(num_crops_w)
        ]

    def _process_images_for_pan_and_scan(
        self,
        images: list["torch.Tensor"],
        do_pan_and_scan: bool,
        pan_and_scan_min_crop_size: int,
        pan_and_scan_max_num_crops: int,
        pan_and_scan_min_ratio_to_activate: float,
    ):
        """
        Run pan and scan on one same-shape group and report the crop count per entry.

        `do_pan_and_scan` is accepted but not read here. Returns `(crops, num_crops)` where
        `num_crops` repeats `len(crops)` once for every entry obtained by iterating `images`.
        """
        crops = self.pan_and_scan_batched(
            images=images,
            pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
            pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
            pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
        )
        crop_total = len(crops)
        return crops, [crop_total for _ in images]

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        resample: "PILImageResampling | tvF.InterpolationMode | int | None",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        disable_grouping: bool | None,
        return_tensors: str | TensorType | None,
        do_pan_and_scan: bool | None = None,
        pan_and_scan_min_crop_size: int | None = None,
        pan_and_scan_max_num_crops: int | None = None,
        pan_and_scan_min_ratio_to_activate: float | None = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Optionally pan-and-scan, resize, then rescale/normalize `images`.

        Returns a `BatchFeature` with `pixel_values` (the processed images) and `num_crops`
        (the crop count recorded for each input image, `0` when pan and scan is disabled).
        """
        # Stage 1: per input-shape group, optional pan and scan followed by resizing.
        resized_by_shape = {}
        crops_by_shape = {}
        shape_groups, shape_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        for group_shape, batch in shape_groups.items():
            if not do_pan_and_scan:
                crop_counts = [0 for _ in batch]
            else:
                crops, crop_counts = self._process_images_for_pan_and_scan(
                    images=batch,
                    do_pan_and_scan=do_pan_and_scan,
                    pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
                    pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
                    pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
                )
                # The uncropped batch (the "thumbnails") goes first, followed by the crops.
                candidates = [batch] + crops
                # Regroup by shape so thumbnails and crops can each be resized together.
                patch_groups, patch_index = group_images_by_shape(candidates, disable_grouping=disable_grouping)
                resized_patches = {}
                for patch_shape, patch_batch in patch_groups.items():
                    resized_patches[patch_shape] = self.resize(
                        image=patch_batch,
                        size=size,
                        resample=resample,
                    )
                ordered_patches = reorder_images(resized_patches, patch_index)
                # Swap the first two axes so each thumbnail sits with its own crops.
                batch = torch.stack(ordered_patches, dim=0).transpose(0, 1).contiguous()

            if do_resize:
                batch = self.resize(
                    image=batch,
                    size=size,
                    resample=resample,
                )
            crops_by_shape[group_shape] = crop_counts
            resized_by_shape[group_shape] = batch

        resized_images = reorder_images(resized_by_shape, shape_index)
        # With pan and scan every entry is a group (thumbnail + crops): flatten it.
        if do_pan_and_scan:
            resized_images = [image for images_list in resized_images for image in images_list]
        num_crops = reorder_images(crops_by_shape, shape_index)

        # Stage 2: regroup by shape (needed if do_resize is False, or if resize returned
        # images with different sizes) and apply the fused rescale + normalize.
        final_groups, final_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        normalized_by_shape = {}
        for final_shape, final_batch in final_groups.items():
            normalized_by_shape[final_shape] = self.rescale_and_normalize(
                final_batch, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )

        processed_images = reorder_images(normalized_by_shape, final_index)
        return BatchFeature(
            data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
        )
248
+
249
+
250
+ __all__ = ["Gemma3ImageProcessor"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/image_processing_pil_gemma3.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Image processor class for Gemma3."""
15
+
16
+ import itertools
17
+ import math
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_backends import PilBackend
22
+ from ...image_processing_utils import BatchFeature
23
+ from ...image_utils import (
24
+ IMAGENET_STANDARD_MEAN,
25
+ IMAGENET_STANDARD_STD,
26
+ ImageInput,
27
+ PILImageResampling,
28
+ SizeDict,
29
+ get_image_size,
30
+ )
31
+ from ...processing_utils import ImagesKwargs, Unpack
32
+ from ...utils import (
33
+ TensorType,
34
+ auto_docstring,
35
+ )
36
+
37
+
38
+ # Adapted from transformers.models.gemma3.image_processing_gemma3.Gemma3ImageProcessorKwargs
39
class Gemma3ImageProcessorKwargs(ImagesKwargs, total=False):
    """
    do_pan_and_scan (`bool`, *optional*):
        Whether to apply `pan_and_scan` to images.
    pan_and_scan_min_crop_size (`int`, *optional*):
        Minimum size of each crop in pan and scan.
    pan_and_scan_max_num_crops (`int`, *optional*):
        Maximum number of crops per image in pan and scan.
    pan_and_scan_min_ratio_to_activate (`float`, *optional*):
        Minimum aspect ratio to activate pan and scan.
    """

    # Gemma3-specific keys added on top of the inherited `ImagesKwargs` keys.
    # NOTE(review): `total=False` suggests `ImagesKwargs` is a TypedDict, making
    # every key below optional — confirm against `processing_utils`.
    do_pan_and_scan: bool
    pan_and_scan_min_crop_size: int
    pan_and_scan_max_num_crops: int
    pan_and_scan_min_ratio_to_activate: float
55
+
56
+
57
@auto_docstring
class Gemma3ImageProcessorPil(PilBackend):
    # Class-level processing defaults; the pan-and-scan options default to None.
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 224, "width": 224}
    default_to_square = True
    do_convert_rgb = True
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_pan_and_scan = None
    pan_and_scan_min_crop_size = None
    pan_and_scan_max_num_crops = None
    pan_and_scan_min_ratio_to_activate = None
    valid_kwargs = Gemma3ImageProcessorKwargs
    model_input_names = ["pixel_values", "num_crops"]

    def __init__(self, **kwargs: Unpack[Gemma3ImageProcessorKwargs]):
        super().__init__(**kwargs)

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Gemma3ImageProcessorKwargs]) -> BatchFeature:
        # Pure pass-through to the backend; overridden only to narrow the kwargs type.
        return super().preprocess(images, **kwargs)

    def pan_and_scan(
        self,
        image: np.ndarray,
        pan_and_scan_min_crop_size: int,
        pan_and_scan_max_num_crops: int,
        pan_and_scan_min_ratio_to_activate: float,
    ):
        """
        Split an elongated image into crops along its longer side.

        Args:
            image (`np.ndarray`):
                Image to crop; its size is read with `channel_dim="channels_first"` and it is
                sliced as `[channels, height, width]`.
            pan_and_scan_min_crop_size (`int`):
                Minimum size of each crop in pan and scan.
            pan_and_scan_max_num_crops (`int`):
                Maximum number of crops per image in pan and scan.
            pan_and_scan_min_ratio_to_activate (`float`):
                Minimum long-side / short-side ratio required to activate pan and scan.

        Returns:
            A list of crops sliced from `image` (rows first, then columns), or an empty list
            when the aspect ratio is below the activation threshold or the resulting crop
            would be smaller than `pan_and_scan_min_crop_size`.
        """
        height, width = get_image_size(image, channel_dim="channels_first")

        # Square images count as landscape, so cropping then happens along the width.
        is_landscape = width >= height
        long_side, short_side = (width, height) if is_landscape else (height, width)

        # Only apply PaS if the image is sufficiently exaggerated.
        if long_side / short_side < pan_and_scan_min_ratio_to_activate:
            return []

        # Crop count close to the aspect ratio (half-round-up), limited so that each crop
        # stays at least `pan_and_scan_min_crop_size` long.
        crops_along_long = int(math.floor(long_side / short_side + 0.5))
        crops_along_long = min(int(math.floor(long_side / pan_and_scan_min_crop_size)), crops_along_long)

        # Clamp into [2, pan_and_scan_max_num_crops]; the upper bound is applied last.
        crops_along_long = min(pan_and_scan_max_num_crops, max(2, crops_along_long))

        num_crops_w, num_crops_h = (crops_along_long, 1) if is_landscape else (1, crops_along_long)
        crop_size_w = int(math.ceil(width / num_crops_w))
        crop_size_h = int(math.ceil(height / num_crops_h))

        # Don't apply PaS if crop size is too small.
        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return []

        # Images are channels-first (CHW format).
        return [
            image[
                :,
                row * crop_size_h : row * crop_size_h + crop_size_h,
                col * crop_size_w : col * crop_size_w + crop_size_w,
            ]
            for row in range(num_crops_h)
            for col in range(num_crops_w)
        ]

    def _process_images_for_pan_and_scan(
        self,
        images: list[np.ndarray],
        do_pan_and_scan: bool,
        pan_and_scan_min_crop_size: int,
        pan_and_scan_max_num_crops: int,
        pan_and_scan_min_ratio_to_activate: float,
    ):
        """
        Run pan and scan on every image and interleave each original with its crops.

        `do_pan_and_scan` is accepted but not read here. Returns `(flat_images, num_crops)`:
        `flat_images` holds each image followed by its crops, and `num_crops` holds the
        number of crops produced for each image.
        """
        flat_images = []
        crop_counts = []
        for image in images:
            crops = self.pan_and_scan(
                image=image,
                pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
            )
            flat_images.append(image)
            flat_images.extend(crops)
            crop_counts.append(len(crops))
        return flat_images, crop_counts

    def _preprocess(
        self,
        images: list[np.ndarray],
        do_resize: bool,
        size: SizeDict,
        resample: "PILImageResampling | None",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        return_tensors: str | TensorType | None,
        do_pan_and_scan: bool | None = None,
        pan_and_scan_min_crop_size: int | None = None,
        pan_and_scan_max_num_crops: int | None = None,
        pan_and_scan_min_ratio_to_activate: float | None = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Optionally pan-and-scan each image, then resize, rescale and normalize it.

        Returns a `BatchFeature` with `pixel_values` (every original followed by its crops,
        in input order) and `num_crops` (crop count per input image, `0` when pan and scan
        is disabled).
        """
        processed_images = []
        num_crops = []

        for image in images:
            crops = []
            if do_pan_and_scan:
                crops = self.pan_and_scan(
                    image=image,
                    pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
                    pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
                    pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
                )
            num_crops.append(len(crops))

            # The original image always comes first, followed by its crops (if any).
            finished = []
            for candidate in [image] + crops:
                if do_resize:
                    candidate = self.resize(image=candidate, size=size, resample=resample)
                if do_rescale:
                    candidate = self.rescale(image=candidate, scale=rescale_factor)
                if do_normalize:
                    candidate = self.normalize(image=candidate, mean=image_mean, std=image_std)
                finished.append(candidate)
            processed_images.extend(finished)

        return BatchFeature(
            data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
        )
223
+
224
+
225
+ __all__ = ["Gemma3ImageProcessorPil"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py ADDED
@@ -0,0 +1,1118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_gemma3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from collections.abc import Callable
22
+ from dataclasses import dataclass
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ from ... import initialization as init
29
+ from ...activations import ACT2FN
30
+ from ...cache_utils import Cache, DynamicCache
31
+ from ...configuration_utils import PreTrainedConfig
32
+ from ...generation import GenerationMixin
33
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
34
+ from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
35
+ from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
36
+ from ...modeling_outputs import (
37
+ BaseModelOutputWithPast,
38
+ BaseModelOutputWithPooling,
39
+ CausalLMOutputWithPast,
40
+ SequenceClassifierOutputWithPast,
41
+ )
42
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from ...processing_utils import Unpack
45
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
46
+ from ...utils.generic import maybe_autocast, merge_with_config_defaults
47
+ from ...utils.output_capturing import capture_outputs
48
+ from ..auto import AutoModel
49
+ from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
50
+
51
+
52
@auto_docstring(
    custom_intro="""
    Base class for Gemma3 outputs, with hidden states and attentions.
    """
)
@dataclass
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # The only field added on top of `BaseModelOutputWithPast`; defaults to None
    # so it can be omitted when constructing the output.
    image_hidden_states: torch.FloatTensor | None = None
66
+
67
+
68
@auto_docstring(
    custom_intro="""
    Base class for Gemma3 causal language model (or autoregressive) outputs.
    """
)
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """

    # Every field defaults to None, so an output can be built from any subset of them.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
96
+
97
+
98
class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    `nn.Embedding` variant whose looked-up vectors are multiplied by a fixed scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Keep the plain Python value alongside the tensor copy registered below.
        self.scalar_embed_scale = embed_scale
        scale_tensor = torch.tensor(embed_scale)
        # Non-persistent buffer: it follows the module but is left out of the state dict.
        self.register_buffer("embed_scale", scale_tensor, persistent=False)

    def forward(self, input_ids: torch.Tensor):
        """Look up `input_ids` and multiply the result by the scale cast to the weight dtype."""
        looked_up = super().forward(input_ids)
        scale = self.embed_scale.to(self.weight.dtype)
        return looked_up * scale
110
+
111
+
112
class Gemma3MLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config: Gemma3TextConfig):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # All three projections are bias-free linear layers.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        # Activation is selected by name from the config.
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated projection to `x` and project back to `hidden_size`."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
126
+
127
+
128
class Gemma3RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learned `(1 + weight)` gain.

    `weight` starts at zeros, so a freshly built module applies a gain of exactly 1.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide `x` by the root of its mean square over the last dimension (plus `eps`)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, apply the gain, and cast back to the dtype of `x`."""
        normed = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        gain = 1.0 + self.weight.float()
        scaled = normed * gain
        return scaled.type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
146
+
147
+
148
class Gemma3RotaryEmbedding(nn.Module):
    """
    Rotary position embedding that keeps a separate set of inverse frequencies per layer type.

    For every distinct entry of `config.layer_types` whose `config.rope_parameters` entry is
    not `None`, the module registers `<layer_type>_inv_freq` and
    `<layer_type>_original_inv_freq` buffers and stores `<layer_type>_attention_scaling`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Gemma3TextConfig):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # De-duplicate so each layer type is initialised exactly once.
        self.layer_types = list(set(config.layer_types))
        self.rope_type = {}
        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            if rope_params is not None:
                rope_kind = rope_params["rope_type"]
                self.rope_type[layer_type] = rope_kind

                # "default" uses the local implementation; anything else comes from the registry.
                rope_init_fn: Callable
                if rope_kind == "default":
                    rope_init_fn = self.compute_default_rope_parameters
                else:
                    rope_init_fn = ROPE_INIT_FUNCTIONS[rope_kind]

                inv_freq, scaling = rope_init_fn(self.config, layer_type=layer_type)
                self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False)
                self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False)
                setattr(self, f"{layer_type}_attention_scaling", scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma3TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters[layer_type]["rope_theta"]

        # Use `head_dim` when it is set and truthy, else derive it from the hidden size.
        head_dim = getattr(config, "head_dim", None)
        dim = head_dim if head_dim else config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # One exponent per even index in [0, dim), normalised by `dim`.
        even_indices = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (base ** (even_indices / dim))
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """
        Return `(cos, sin)` for `position_ids`, using the buffers stored for `layer_type`.

        `x` is only consulted for its device and dtype; both outputs are cast to `x.dtype`.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        batch = position_ids.shape[0]
        inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and any non-string device type) falls back to "cpu" for the autocast context.
        x_device_type = x.device.type
        if isinstance(x_device_type, str) and x_device_type != "mps":
            device_type = x_device_type
        else:
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
225
+
226
+
227
def rotate_half(x):
    """Split the last dimension in two halves `(a, b)` and return `(-b, a)` concatenated on that dimension."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
232
+
233
+
234
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast against `q` and `k`. With `cos`/`sin`
            of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q`/`k` are laid out as
            [batch_size, heads, seq_len, head_dim] and `unsqueeze_dim=2` when they are laid out as
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    # Add the broadcast (heads) axis once and reuse it for both tensors.
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Rotation written as x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
258
+
259
+
260
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times, consecutively, along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking up front also enforces that the input is 4-dimensional, even on the n_rep == 1 path.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        tiled = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
270
+
271
+
272
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float | int = 0.0,
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference (pure PyTorch) scaled dot-product attention with optional tanh soft-capping.

    Args:
        module: Attention module; `head_dim` (only when `scaling` is None), `num_key_value_groups` and
            `training` are read from it.
        query, key, value: Attention inputs; `key`/`value` heads are repeated `num_key_value_groups` times.
        attention_mask: Additive mask summed onto the scores, or `None`.
        dropout: Dropout probability applied to the attention probabilities (active only in training mode).
        scaling: Score multiplier; defaults to `module.head_dim ** -0.5`.
        softcap: If given, scores are squashed as `tanh(scores / softcap) * softcap` before masking.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`; `attn_output` has dimensions 1 and 2 swapped relative to the
        matmul result and is contiguous.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling

    group_count = module.num_key_value_groups
    key_states = repeat_kv(key, group_count)
    value_states = repeat_kv(value, group_count)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scale

    if softcap is not None:
        scores = scores / softcap
        scores = torch.tanh(scores)
        scores = scores * softcap
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax is computed in fp32 and cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states)
    context = context.transpose(1, 2).contiguous()
    return context, probs
304
+
305
+
306
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Queries and keys go through `q_norm` / `k_norm` (`Gemma3RMSNorm` over `head_dim`) and are rotated with
    RoPE before the attention backend selected by `config._attn_implementation` is called.
    """

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        """
        Args:
            config: Text configuration providing sizes, dropout, bias and sliding-window settings.
            layer_idx: Index of this layer; selects `config.layer_types[layer_idx]` and is passed to the
                KV cache on update.
        """
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        # Resolve the head size once, tolerating both a missing and a `None` `head_dim`. This is the same
        # rule the rotary embedding uses, so projections, norms and RoPE always agree on the size.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = not self.config.use_bidirectional_attention

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
        self.is_sliding = self.layer_type == "sliding_attention"

        # Use the resolved `self.head_dim` (not `config.head_dim`) so the norms match the projections
        # even when the config does not define `head_dim`.
        self.q_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states: Input whose last dimension is projected by `q_proj` / `k_proj` / `v_proj`.
            position_embeddings: `(cos, sin)` pair applied through `apply_rotary_pos_emb`. The default of
                `None` is kept for signature compatibility, but a pair is required: it is unpacked
                unconditionally.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_values: Optional cache; when given, keys/values are replaced by the result of
                `past_key_values.update(...)` for this layer.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has been passed through `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        # Split the projection output into (heads, head_dim); -1 lets q and k/v use different head counts.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Fall back to the eager implementation when the configured backend is not registered.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
383
+
384
+
385
class Gemma3DecoderLayer(GradientCheckpointingLayer):
    """One decoder block: self-attention and MLP, each wrapped in a pre-norm, a post-norm and a residual add."""

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        """
        Args:
            config: Text configuration (hidden size, RMSNorm epsilon, attention/MLP settings).
            layer_idx: Index of this layer, forwarded to `Gemma3Attention`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma3MLP(config)
        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input to the block.
            position_embeddings: Rotary `(cos, sin)` pair forwarded to the attention module.
            attention_mask: Mask forwarded to the attention module.
            position_ids: Forwarded to the attention module as a keyword argument.
            past_key_values: Optional KV cache forwarded to the attention module.
            **kwargs: Forwarded to the attention module.

        Returns:
            The updated hidden states (a single tensor; attention weights are discarded here).
        """
        # --- attention sub-block: norm -> attention -> norm -> residual add ---
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # --- feed-forward sub-block: norm -> MLP -> norm -> residual add ---
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
429
+
430
+
431
@auto_docstring
class Gemma3PreTrainedModel(PreTrainedModel):
    # Shared base class of the Gemma3 models: declares capability flags and the custom weight init.
    config: Gemma3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "Gemma3DecoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module classes whose outputs are collected for the `hidden_states` / `attentions` fields.
    _can_record_outputs = {
        "hidden_states": Gemma3DecoderLayer,
        "attentions": Gemma3Attention,
    }
    input_modalities = ("image", "text")

    @torch.no_grad()
    def _init_weights(self, module):
        # Run the generic initialisation first, then apply the Gemma3-specific overrides below.
        super()._init_weights(module)

        if isinstance(module, Gemma3MultiModalProjector):
            init.zeros_(module.mm_input_projection_weight)
            return

        if "RMSNorm" in module.__class__.__name__:
            # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
            init.zeros_(module.weight)
            return

        if isinstance(module, Gemma3TextScaledWordEmbedding):
            init.constant_(module.embed_scale, module.scalar_embed_scale)
            return

        if isinstance(module, Gemma3RotaryEmbedding):
            # Recompute the inverse frequencies of every layer type and refresh both buffers.
            for layer_type in module.layer_types:
                rope_init_fn = module.compute_default_rope_parameters
                rope_kind = module.rope_type[layer_type]
                if rope_kind != "default":
                    rope_init_fn = ROPE_INIT_FUNCTIONS[rope_kind]
                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                for buffer_name in (f"{layer_type}_inv_freq", f"{layer_type}_original_inv_freq"):
                    init.copy_(getattr(module, buffer_name), curr_inv_freq)
473
+
474
+
475
+ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
476
+ """
477
+ Enables a bidirectional mask within the sliding window.
478
+ """
479
+
480
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
481
+ """A token can attend to any other token if their absolute distance is within
482
+ the (exclusive) sliding window size (distance < sliding_window)."""
483
+ return abs(q_idx - kv_idx) < sliding_window
484
+
485
+ return inner_mask
486
+
487
+
488
@auto_docstring
class Gemma3TextModel(Gemma3PreTrainedModel):
    # Text-only decoder stack: scaled token embedding -> decoder layers -> final RMSNorm.
    config: Gemma3TextConfig
    input_modalities = ("text",)

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
        )
        decoder_layers = [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two input forms must be given (both missing or both present is an error).
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            sliding_mask_kwargs = mask_kwargs.copy()

            if self.config.use_bidirectional_attention:
                # Full-attention layers see everything; sliding layers see a symmetric window.
                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)

            # Create the masks
            full_mask = create_causal_mask(**mask_kwargs)
            sliding_mask = create_sliding_window_causal_mask(**sliding_mask_kwargs)
            causal_mask_mapping = {
                "full_attention": full_mask,
                "sliding_attention": sliding_mask,
            }

        # embed positions: one (cos, sin) pair per distinct layer type
        hidden_states = inputs_embeds
        position_embeddings = {}
        for layer_type in set(self.config.layer_types):
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer_index, decoder_layer in enumerate(active_layers):
            layer_kind = self.config.layer_types[layer_index]
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[layer_kind],
                position_embeddings=position_embeddings[layer_kind],
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
583
+
584
+
585
@auto_docstring
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
    # Text-only Gemma3 with a language-modeling head; the head weight is tied to the input embedding.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Gemma3TextConfig

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        self.model = Gemma3TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma3ForCausalLM

        >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-3-1b-it")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        # Run the decoder stack; the result exposes `last_hidden_state` and `past_key_values`.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            # Soft-cap: squash logits into (-cap, cap) with tanh.
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
663
+
664
+
665
class Gemma3MultiModalProjector(nn.Module):
    """Average-pools vision features on a square patch grid, normalises them and projects them with
    `mm_input_projection_weight` (vision hidden size -> text hidden size)."""

    def __init__(self, config: Gemma3Config):
        super().__init__()
        vision_config = config.vision_config

        # Projection matrix of shape (vision hidden size, text hidden size), created as zeros.
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_config.hidden_size, config.text_config.hidden_size)
        )

        self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)

        # Patches along one image side, tokens kept along one side, and the pooling factor between them.
        self.patches_per_image = int(vision_config.image_size // vision_config.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        """Pool, normalise and project `vision_outputs` (3-D: batch, patches, hidden); the result is cast
        back to the type of `vision_outputs`."""
        batch_size, _, hidden_size = vision_outputs.shape
        side = self.patches_per_image

        # Channels-first 2-D grid so that AvgPool2d can operate on it.
        grid = vision_outputs.transpose(1, 2).reshape(batch_size, hidden_size, side, side).contiguous()

        # Pool, then flatten the grid back into a token sequence with the hidden dimension last.
        pooled = self.avg_pool(grid).flatten(2).transpose(1, 2)

        normed = self.mm_soft_emb_norm(pooled)

        projected = torch.matmul(normed, self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)
699
+
700
+
701
+ def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch.device | None = None) -> torch.Tensor:
702
+ # First find where a new image block starts: 1 if image and previous not image
703
+ # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
704
+ is_image = (token_type_ids == 1).to(device=device)
705
+ is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
706
+ new_image_start = is_image & ~is_previous_image
707
+ group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
708
+ block_sequence_ids = torch.where(is_image, group_ids, -1)
709
+ return block_sequence_ids
710
+
711
+
712
@auto_docstring(
    custom_intro="""
    The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
    """
)
class Gemma3Model(Gemma3PreTrainedModel):
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        self.language_model = AutoModel.from_config(config=config.text_config)
        self.post_init()

    @can_return_tuple
    @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
        # The projected features are stored on the vision output object itself, as `pooler_output`.
        vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state)

        return vision_outputs

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Locate the image placeholder positions and verify that they can hold `image_features`.

        The mask comes from `input_ids` when available; otherwise positions are detected by comparing
        `inputs_embeds` with the embedding of `config.image_token_id`. The returned boolean mask is expanded
        to the shape of `inputs_embeds`. A check fails if the number of masked embedding elements differs
        from the number of elements in `image_features`.
        """
        if input_ids is not None:
            special_image_mask = input_ids == self.config.image_token_id
        else:
            # No ids available: a position is a placeholder if its whole embedding equals the image-token one.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)

        # Counts are only used for the error message below.
        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **lm_kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Gemma3ModelOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma32-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/gemma32-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        # Exactly one of the two input forms must be given (both missing or both present is an error).
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # If the image token id lies outside the text vocabulary, swap it for id 0 on a copy of the ids
        # so that the embedding lookup cannot index out of range.
        llm_input_ids = input_ids
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            llm_input_ids = input_ids.clone()
            llm_input_ids[input_ids == self.config.image_token_id] = 0

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # Merge text and images: scatter the projected image features into the placeholder positions.
        image_hidden_states = None
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            placeholder_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(placeholder_mask, image_features)
            image_hidden_states = image_features

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            mask_kwargs = {
                "config": self.config.get_text_config(),
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }

            if token_type_ids is not None:
                mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
                    token_type_ids, device=inputs_embeds.device
                )

            # Create the masks; both variants receive the same arguments.
            sliding_mask_kwargs = mask_kwargs.copy()
            full_mask = create_causal_mask(**mask_kwargs)
            sliding_mask = create_sliding_window_causal_mask(**sliding_mask_kwargs)
            causal_mask_mapping = {
                "full_attention": full_mask,
                "sliding_attention": sliding_mask,
            }

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=True,
            **lm_kwargs,
        )

        return Gemma3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )
872
+
873
+
874
+ @auto_docstring(
875
+ custom_intro="""
876
+ The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
877
+ """
878
+ )
879
+ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
880
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
881
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
882
+ # Fix: https://github.com/huggingface/transformers/issues/40564
883
+ accepts_loss_kwargs = False
884
+
885
def __init__(self, config: Gemma3Config):
    # Multimodal backbone plus a bias-free LM head sized from the text configuration.
    super().__init__(config)
    text_config = config.text_config
    self.model = Gemma3Model(config)
    self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
    self.post_init()
890
+
891
@auto_docstring
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
    # Thin delegation: the wrapped `Gemma3Model` owns the vision tower and the projector.
    backbone = self.model
    return backbone.get_image_features(pixel_values, **kwargs)
894
+
895
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    token_type_ids: torch.LongTensor | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    use_cache: bool | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **lm_kwargs: Unpack[TransformersKwargs],
) -> tuple | Gemma3CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

    Example:

    ```python
    >>> from PIL import Image
    >>> import httpx
    >>> from io import BytesIO
    >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
    >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

    >>> messages = [
    ...     {
    ...         "role": "system",
    ...         "content": [
    ...             {"type": "text", "text": "You are a helpful assistant."}
    ...         ]
    ...     },
    ...     {
    ...         "role": "user", "content": [
    ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
    ...             {"type": "text", "text": "Where is the cat standing?"},
    ...         ]
    ...     },
    ... ]

    >>> inputs = processor.apply_chat_template(
    ...     messages,
    ...     tokenize=True,
    ...     return_dict=True,
    ...     return_tensors="pt",
    ...     add_generation_prompt=True
    ... )
    >>> # Generate
    >>> generate_ids = model.generate(**inputs)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
    ```
    """
    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        labels=labels,
        return_dict=True,
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
    # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
    if isinstance(logits_to_keep, int):
        slice_indices = slice(-logits_to_keep, None)
    else:
        slice_indices = logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    loss = None
    if labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Next-token objective: position t predicts label t + 1.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        if attention_mask is None:
            shift_logits = shift_logits.contiguous()
            shift_labels = shift_labels.contiguous()
        else:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            keep_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
            shift_logits = shift_logits[keep_mask.to(logits.device) != 0].contiguous()
            shift_labels = shift_labels[keep_mask.to(shift_labels.device) != 0].contiguous()

        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
        flat_labels = shift_labels.view(-1).to(shift_logits.device)
        loss = loss_fct(flat_logits, flat_labels)

    return Gemma3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
    )
1005
+
1006
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    position_ids=None,
    pixel_values=None,
    attention_mask=None,
    token_type_ids=None,
    use_cache=True,
    logits_to_keep=None,
    labels=None,
    is_first_iteration=False,
    **kwargs,
):
    # Overwritten -- custom `pixel_values` handling.
    # Every argument except `pixel_values` and `labels` is handed to the parent
    # implementation unchanged; `pixel_values` is attached (or not) afterwards.
    forwarded = {
        "past_key_values": past_key_values,
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "use_cache": use_cache,
        "logits_to_keep": logits_to_keep,
        "token_type_ids": token_type_ids,
        "is_first_iteration": is_first_iteration,
    }
    model_inputs = super().prepare_inputs_for_generation(input_ids, **forwarded, **kwargs)

    # Pixel values are used only in the first iteration if available.
    # In subsequent iterations, they are already merged with text and cached.
    # NOTE: first iteration doesn't have to be prefill, it can be the first
    # iteration with a question and cached system prompt (continue generate from cache).
    # NOTE: use_cache=False needs pixel_values always.
    if use_cache and not is_first_iteration:
        # Don't pass to not apply bidirectional mask on top
        model_inputs["token_type_ids"] = None
    else:
        model_inputs["pixel_values"] = pixel_values

    return model_inputs
1046
+
1047
@staticmethod
def create_masks_for_generate(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None,
    token_type_ids: torch.Tensor | None = None,
    is_first_iteration: bool | None = False,
    **kwargs,
) -> dict:
    # Builds the mask arguments from the text sub-config and delegates to the
    # `create_masks_for_generate` helper. `is_first_iteration` and `**kwargs`
    # are accepted but not forwarded.
    # NOTE(review): the call below is presumably the module-level helper of the
    # same name rather than this static method -- confirm against the imports.
    text_config = config.get_text_config()

    # Only add block ids when token types are known; they are derived from
    # `token_type_ids` on the device of the input embeddings.
    block_ids = {}
    if token_type_ids is not None:
        block_ids["block_sequence_ids"] = get_block_sequence_ids_for_mask(
            token_type_ids, device=inputs_embeds.device
        )

    return create_masks_for_generate(
        config=text_config,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        position_ids=position_ids,
        **block_ids,
    )
1072
+
1073
+
1074
@auto_docstring(
    custom_intro="""
    Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
    It uses the generic sequence classification implementation for efficiency and consistency."""
)
class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
    # No behaviour is defined here: the class only combines its two bases,
    # pins the config type and restricts the declared input modalities to text.
    config: Gemma3TextConfig
    input_modalities = ("text",)
1082
+
1083
+
1084
class Gemma3ForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
    # Sequence-classification variant whose `forward` spells out `pixel_values`
    # and `token_type_ids` explicitly and passes everything on to the parent
    # `forward`.

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        # Pure pass-through: nothing is transformed before delegating.
        named_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "pixel_values": pixel_values,
            "token_type_ids": token_type_ids,
            "labels": labels,
        }
        return super().forward(**named_inputs, **kwargs)
1108
+
1109
+
1110
+ __all__ = [
1111
+ "Gemma3PreTrainedModel",
1112
+ "Gemma3TextModel",
1113
+ "Gemma3ForCausalLM",
1114
+ "Gemma3ForConditionalGeneration",
1115
+ "Gemma3Model",
1116
+ "Gemma3ForSequenceClassification",
1117
+ "Gemma3TextForSequenceClassification",
1118
+ ]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/modular_gemma3.py ADDED
@@ -0,0 +1,941 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from collections.abc import Callable
16
+ from typing import Any, Optional
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from huggingface_hub.dataclasses import strict
21
+
22
+ from ... import initialization as init
23
+ from ...cache_utils import Cache, DynamicCache
24
+ from ...configuration_utils import PreTrainedConfig
25
+ from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
26
+ from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
27
+ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, SequenceClassifierOutputWithPast
28
+ from ...modeling_rope_utils import (
29
+ ROPE_INIT_FUNCTIONS,
30
+ dynamic_rope_update,
31
+ )
32
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
33
+ from ...processing_utils import Unpack
34
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
35
+ from ...utils.generic import maybe_autocast
36
+ from ..gemma2.configuration_gemma2 import Gemma2Config
37
+ from ..gemma2.modeling_gemma2 import (
38
+ Gemma2Attention,
39
+ Gemma2ForCausalLM,
40
+ Gemma2MLP,
41
+ Gemma2Model,
42
+ Gemma2PreTrainedModel,
43
+ Gemma2RMSNorm,
44
+ Gemma2RotaryEmbedding,
45
+ apply_rotary_pos_emb,
46
+ eager_attention_forward,
47
+ )
48
+ from ..paligemma.modeling_paligemma import (
49
+ PaliGemmaCausalLMOutputWithPast,
50
+ PaliGemmaForConditionalGeneration,
51
+ PaliGemmaModel,
52
+ PaligemmaModelOutputWithPast,
53
+ )
54
+ from ..siglip import SiglipVisionConfig
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+
60
@auto_docstring(checkpoint="google/gemma-3-4b-it")
@strict
class Gemma3TextConfig(Gemma2Config, PreTrainedConfig):
    r"""
    query_pre_attn_scalar (`float`, *optional*, defaults to 256):
        scaling factor used on the attention scores
    final_logit_softcapping (`float`, *optional*):
        Scaling factor when applying tanh softcapping on the logits.
    attn_logit_softcapping (`float`, *optional*):
        Scaling factor when applying tanh softcapping on the attention scores.
    use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
        If True, the model will attend to all text tokens instead of using a causal mask. This does not change
        behavior for vision tokens.

    ```python
    >>> from transformers import Gemma3TextModel, Gemma3TextConfig
    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
    >>> configuration = Gemma3TextConfig()
    >>> # Initializing a model from the gemma3_text-7b style configuration
    >>> model = Gemma3TextModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "gemma3_text"
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Fallback RoPE base frequencies, used by `convert_rope_params_to_dict`:
    # "global" for full-attention layers, "local" for sliding-attention layers.
    default_theta = {"global": 1_000_000.0, "local": 10_000.0}

    vocab_size: int = 262_208
    max_position_embeddings: int = 131_072
    layer_types: list[str] | None = None
    final_logit_softcapping: float | None = None
    attn_logit_softcapping: float | None = None
    rope_parameters: dict | None = None
    use_bidirectional_attention: bool | None = False

    def __post_init__(self, **kwargs):
        if self.use_bidirectional_attention:
            # due to fa we set exclusive bounds
            self.sliding_window = (self.sliding_window // 2) + 1

        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

        if self.layer_types is None:
            # Every `_sliding_window_pattern`-th layer (counting from 1) uses
            # full attention; all other layers use sliding attention.
            pattern = self._sliding_window_pattern
            self.layer_types = [
                "full_attention" if layer_number % pattern == 0 else "sliding_attention"
                for layer_number in range(1, self.num_hidden_layers + 1)
            ]

        # NOTE(review): invoked on the class without `self`; presumably this line is
        # rewritten when the modeling/configuration files are generated from this
        # modular file -- confirm before executing this module directly.
        PreTrainedConfig.__post_init__(**kwargs)

    def convert_rope_params_to_dict(self, **kwargs):
        legacy_scaling = kwargs.pop("rope_scaling", None)

        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
        # as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
        if self.rope_parameters is None:
            self.rope_parameters = {
                "sliding_attention": {"rope_type": "default"},
                "full_attention": {"rope_type": "default"},
            }
        if legacy_scaling is not None:
            self.rope_parameters["full_attention"].update(legacy_scaling)

        # Set default values if not present. The legacy kwarg is always popped from
        # `kwargs`, but it only takes effect when `rope_theta` is not already set.
        per_layer_defaults = (
            ("full_attention", "rope_theta", "global"),
            ("sliding_attention", "rope_local_base_freq", "local"),
        )
        for layer_type, legacy_key, theta_key in per_layer_defaults:
            if self.rope_parameters.get(layer_type) is None:
                self.rope_parameters[layer_type] = {"rope_type": "default"}
            fallback_theta = kwargs.pop(legacy_key, self.default_theta[theta_key])
            self.rope_parameters[layer_type].setdefault("rope_theta", fallback_theta)

        # Standardize and validate the correctness of rotary position embeddings parameters
        self.standardize_rope_params()
        return kwargs
150
+
151
+
152
@auto_docstring(checkpoint="google/gemma-3-4b-it")
@strict
class Gemma3Config(PreTrainedConfig):
    r"""
    mm_tokens_per_image (`int`, *optional*, defaults to 256):
        The number of tokens per image embedding.
    boi_token_index (`int`, *optional*, defaults to 255999):
        The begin-of-image token index to wrap the image prompt.
    eoi_token_index (`int`, *optional*, defaults to 256000):
        The end-of-image token index to wrap the image prompt.

    Example:

    ```python
    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Gemma3 gemma-3-4b style configuration
    >>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)

    >>> # Initializing a model from the gemma-3-4b style configuration
    >>> model = Gemma3ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma3"
    # Maps the `*_token_id` spellings onto the `*_token_index` fields declared below.
    attribute_map = {
        "image_token_id": "image_token_index",
        "boi_token_id": "boi_token_index",
        "eoi_token_id": "eoi_token_index",
    }
    # Config classes used to build the two nested configs in `__post_init__`.
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    text_config: Gemma3TextConfig | dict[str, Any] | None = None
    vision_config: SiglipVisionConfig | dict[str, Any] | None = None
    mm_tokens_per_image: int | None = 256
    boi_token_index: int | None = 255_999
    eoi_token_index: int | None = 256_000
    image_token_index: int | None = 262_144
    initializer_range: float | None = 0.02
    tie_word_embeddings: bool | None = True

    def __post_init__(self, **kwargs):
        # Normalise both sub-configs to config objects: `None` becomes the default
        # config (and is logged), a plain dict is expanded into the config class,
        # and anything else is kept as given.
        if self.text_config is None:
            self.text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig text config.")
        elif isinstance(self.text_config, dict):
            self.text_config = Gemma3TextConfig(**self.text_config)

        if isinstance(self.vision_config, dict):
            self.vision_config = SiglipVisionConfig(**self.vision_config)
        elif self.vision_config is None:
            self.vision_config = SiglipVisionConfig()
            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")

        super().__post_init__(**kwargs)
218
+
219
+
220
class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
    # Gemma3-named alias of the PaliGemma model output; nothing is added or overridden.
    pass
222
+
223
+
224
class Gemma3CausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast):
    # Gemma3-named alias of the PaliGemma causal-LM output; nothing is added or overridden.
    pass
226
+
227
+
228
class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    Embedding table whose looked-up vectors are multiplied by a fixed scale factor.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Keep the plain Python value as well as the tensor copy.
        self.scalar_embed_scale = embed_scale
        # Non-persistent buffer: it follows the module across devices but is
        # left out of the state dict.
        scale_buffer = torch.tensor(embed_scale)
        self.register_buffer("embed_scale", scale_buffer, persistent=False)

    def forward(self, input_ids: torch.Tensor):
        unscaled = super().forward(input_ids)
        # Cast the scale to the weight dtype before multiplying.
        scale = self.embed_scale.to(self.weight.dtype)
        return unscaled * scale
240
+
241
+
242
class Gemma3MLP(Gemma2MLP):
    # Same MLP as Gemma2; only the constructor's config annotation differs.
    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
245
+
246
+
247
class Gemma3RMSNorm(Gemma2RMSNorm):
    # Same norm as Gemma2; the constructor forwards `dim` and `eps` by keyword.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__(dim=dim, eps=eps)
250
+
251
+
252
class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding, nn.Module):
    # Rotary embedding that keeps one set of inverse frequencies (plus an
    # attention scaling factor) per layer type found in `config.layer_types`.

    def __init__(self, config: Gemma3TextConfig):
        nn.Module.__init__(self)
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.layer_types = list(set(config.layer_types))
        self.rope_type = {}

        for layer_type in self.layer_types:
            params = self.config.rope_parameters[layer_type]
            if params is None:
                # No RoPE parameters for this layer type: nothing is registered.
                continue

            kind = params["rope_type"]
            self.rope_type[layer_type] = kind
            rope_init_fn: Callable = (
                self.compute_default_rope_parameters if kind == "default" else ROPE_INIT_FUNCTIONS[kind]
            )
            inv_freq, attention_scaling = rope_init_fn(self.config, layer_type=layer_type)

            # Two non-persistent buffers per layer type: the working frequencies and
            # an untouched copy of the initial values.
            self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False)
            self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma3TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        # The base frequency is read from the per-layer-type RoPE parameters.
        theta = config.rope_parameters[layer_type]["rope_theta"]
        # Prefer an explicit `head_dim`; otherwise derive it from hidden size and head count.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Exponents 0/dim, 2/dim, 4/dim, ... built as int64 first, then cast to float on `device`.
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (theta**exponents)

        attention_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        # Pick the buffers that belong to the requested layer type.
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        batch = position_ids.shape[0]
        inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled on the device type of `x`; "mps" and non-string
        # device types fall back to "cpu".
        x_device_type = x.device.type
        if isinstance(x_device_type, str) and x_device_type != "mps":
            device_type = x_device_type
        else:
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        # Hand the tables back in the dtype of `x`.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
327
+
328
+
329
# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
class Gemma3Attention(Gemma2Attention):
    # Gemma2 attention with per-head RMS normalisation of queries and keys, and
    # a sliding window that is only active on "sliding_attention" layers.

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        uses_sliding_window = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if uses_sliding_window else None
        self.is_sliding = uses_sliding_window
        self.is_causal = not self.config.use_bidirectional_attention

        self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        # All leading dims are kept; the last dim is split into (heads, head_dim).
        leading_dims = hidden_states.shape[:-1]
        per_head_shape = (*leading_dims, -1, self.head_dim)

        # Project, split into heads, and swap dims 1 and 2.
        query_states = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        # Queries and keys are normalised before the rotary embedding is applied.
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # The cache returns the key/value states to attend over for this layer.
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Resolve the configured attention backend, defaulting to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Dropout is only active in training mode.
        dropout_p = self.attention_dropout if self.training else 0.0

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge the heads back into one feature dim, then apply the output projection.
        merged = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.o_proj(merged), attn_weights
383
+
384
+
385
class Gemma3DecoderLayer(GradientCheckpointingLayer):
    # One decoder block: self-attention and MLP, each wrapped by a norm before
    # and a norm after, with a residual connection around each wrapped sub-block.

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma3MLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=norm_eps)
        self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=norm_eps)
        self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        # --- attention sub-block ---
        attn_residual = hidden_states
        normed_input = self.input_layernorm(hidden_states)
        # The attention weights returned alongside the output are discarded.
        attn_out, _ = self.self_attn(
            hidden_states=normed_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = attn_residual + self.post_attention_layernorm(attn_out)

        # --- feed-forward sub-block ---
        mlp_residual = hidden_states
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = mlp_residual + self.post_feedforward_layernorm(mlp_out)

        return hidden_states
429
+
430
+
431
+ GEMMA3_START_DOCSTRING = None
432
+
433
+
434
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    _no_split_modules = [
        "Gemma3DecoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        # Run the generic initialisation first, then apply the Gemma3-specific
        # overrides. The checks run in this order and only the first match applies.
        PreTrainedModel._init_weights(self, module)

        if isinstance(module, Gemma3MultiModalProjector):
            init.zeros_(module.mm_input_projection_weight)
            return

        if "RMSNorm" in module.__class__.__name__:
            # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
            init.zeros_(module.weight)
            return

        if isinstance(module, Gemma3TextScaledWordEmbedding):
            # Restore the scale buffer from the scalar stored on the module.
            init.constant_(module.embed_scale, module.scalar_embed_scale)
            return

        if not isinstance(module, Gemma3RotaryEmbedding):
            return

        # Recompute the inverse frequencies for every layer type and copy them into
        # both the working buffer and the "original" buffer.
        for layer_type in module.layer_types:
            kind = module.rope_type[layer_type]
            if kind == "default":
                rope_init_fn = module.compute_default_rope_parameters
            else:
                rope_init_fn = ROPE_INIT_FUNCTIONS[kind]
            fresh_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
            init.copy_(getattr(module, f"{layer_type}_inv_freq"), fresh_inv_freq)
            init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), fresh_inv_freq)
462
+
463
+
464
+ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
465
+ """
466
+ Enables a bidirectional mask within the sliding window.
467
+ """
468
+
469
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
470
+ """A token can attend to any other token if their absolute distance is within
471
+ the (exclusive) sliding window size (distance < sliding_window)."""
472
+ return abs(q_idx - kv_idx) < sliding_window
473
+
474
+ return inner_mask
475
+
476
+
477
class Gemma3TextModel(Gemma2Model):
    config: Gemma3TextConfig
    input_modalities = ("text",)

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)

        # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        embed_scale = self.config.hidden_size**0.5
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two inputs has to be given: reject "both" and "neither".
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue counting after whatever the cache already holds.
            offset = 0 if past_key_values is None else past_key_values.get_seq_length()
            steps = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = (steps + offset).unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            # Prepare mask arguments
            full_mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            sliding_mask_kwargs = dict(full_mask_kwargs)

            if self.config.use_bidirectional_attention:
                # Full-attention layers get an always-true overlay; sliding layers get
                # an overlay that is symmetric around the query position.
                full_mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)

            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**full_mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
            }

        # embed positions: one rotary-embedding result per distinct layer type
        hidden_states = inputs_embeds
        position_embeddings = {
            layer_type: self.rotary_emb(hidden_states, position_ids, layer_type)
            for layer_type in set(self.config.layer_types)
        }

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer_index, decoder_layer in enumerate(active_layers):
            # Each layer gets the mask and position embeddings of its own type.
            layer_type = self.config.layer_types[layer_index]
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[layer_type],
                position_embeddings=position_embeddings[layer_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
557
+
558
+
559
class Gemma3ForCausalLM(Gemma2ForCausalLM):
    # Gemma2 causal-LM wrapper with the backbone swapped for `Gemma3TextModel`.
    config: Gemma3TextConfig

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        # Replace whatever `model` the parent constructor created.
        self.model = Gemma3TextModel(config)
565
+
566
+
567
class Gemma3MultiModalProjector(nn.Module):
    # Average-pools the vision encoder's patch grid, normalises the pooled
    # tokens and projects them from the vision hidden size to the text hidden size.

    def __init__(self, config: Gemma3Config):
        super().__init__()
        vision_config = config.vision_config

        # Projection matrix of shape (vision hidden size, text hidden size), created as zeros.
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_config.hidden_size, config.text_config.hidden_size)
        )

        self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)

        # Patches per image side, target tokens per side, and the pooling
        # window/stride that reduces the former to the latter.
        self.patches_per_image = int(vision_config.image_size // vision_config.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        # `vision_outputs` must be 3-D; its middle dim is folded into a
        # (patches_per_image, patches_per_image) grid below.
        batch_size, _, hidden_size = vision_outputs.shape

        patch_grid = (
            vision_outputs.transpose(1, 2)
            .reshape(batch_size, hidden_size, self.patches_per_image, self.patches_per_image)
            .contiguous()
        )

        # Pool over the grid, then flatten it back into a token dim placed before the feature dim.
        pooled_tokens = self.avg_pool(patch_grid).flatten(2).transpose(1, 2)

        normed_tokens = self.mm_soft_emb_norm(pooled_tokens)
        projected_tokens = torch.matmul(normed_tokens, self.mm_input_projection_weight)

        # Return in the dtype of the incoming vision features.
        return projected_tokens.type_as(vision_outputs)
601
+
602
+
603
+ def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch.device | None = None) -> torch.Tensor:
604
+ # First find where a new image block starts: 1 if image and previous not image
605
+ # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
606
+ is_image = (token_type_ids == 1).to(device=device)
607
+ is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
608
+ new_image_start = is_image & ~is_previous_image
609
+ group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
610
+ block_sequence_ids = torch.where(is_image, group_ids, -1)
611
+ return block_sequence_ids
612
+
613
+
614
# Multimodal Gemma3 backbone: embeds text, splices projected image features
# into the placeholder positions, builds the attention masks and runs the
# language model. Built on `PaliGemmaModel`.
class Gemma3Model(PaliGemmaModel):
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        # Attribute set up by the parent constructor that this subclass drops.
        del self.text_config_dtype

    @can_return_tuple
    @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
        last_hidden_state = vision_outputs.last_hidden_state
        # The projected features are stored on the vision output itself, under `pooler_output`.
        vision_outputs.pooler_output = self.multi_modal_projector(last_hidden_state)

        return vision_outputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **lm_kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Gemma3ModelOutputWithPast:
        # XOR check: raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Replace image id with PAD if the image token is OOV, to avoid index-errors
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            # Clone so the caller's `input_ids` tensor is left untouched.
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
            # Match the device and dtype of the text embeddings before scattering.
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # It may already have been prepared by e.g. `generate`
        # (a dict is forwarded unchanged; anything else is turned into a per-layer-type mask mapping here)
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config.get_text_config(),
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }

            # Block ids for image tokens are only passed to the mask builders when token types are provided.
            if token_type_ids is not None:
                mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
                    token_type_ids, device=inputs_embeds.device
                )

            # Create the masks
            sliding_mask_kwargs = mask_kwargs.copy()
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
            }

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=True,
            **lm_kwargs,
        )

        return Gemma3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # `image_features` only exists when `pixel_values` was provided.
            image_hidden_states=image_features if pixel_values is not None else None,
        )
710
+
711
+
712
# Image+text generation head: runs the multimodal backbone in `self.model`,
# applies `self.lm_head`, and optionally computes a shifted cross-entropy loss.
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    # Fix: https://github.com/huggingface/transformers/issues/40564
    accepts_loss_kwargs = False

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **lm_kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Gemma3CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
        >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

        >>> messages = [
        ...     {
        ...         "role": "system",
        ...         "content": [
        ...             {"type": "text", "text": "You are a helpful assistant."}
        ...         ]
        ...     },
        ...     {
        ...         "role": "user", "content": [
        ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
        ...             {"type": "text", "text": "Where is the cat standing?"},
        ...         ]
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     return_tensors="pt",
        ...     add_generation_prompt=True
        ... )
        >>> # Generate
        >>> generate_ids = model.generate(**inputs)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
        ```
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            return_dict=True,
            **lm_kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # (an int keeps the last `logits_to_keep` positions -- 0 keeps all; anything else is used as an index directly)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Next-token objective: logits at position t are scored against the label at t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                # Boolean indexing keeps only positions whose mask entry is non-zero.
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- custom `pixel_values` handling
        # NOTE(review): `labels` and `pixel_values` are accepted here but not forwarded to the parent call below.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Pixel values are used only in the first iteration if available
        # In subsequent iterations, they are already merged with text and cached
        # NOTE: first iteration doesn't have to be prefill, it can be the first
        # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values
        else:
            # Don't pass to not apply bidirectional mask on top
            model_inputs["token_type_ids"] = None

        return model_inputs

    # NOTE(review): no decorator and no `self` parameter here; presumably the `@staticmethod` of the parent
    # class is carried over when the modeling file is generated from this modular file -- confirm.
    def create_masks_for_generate(
        config: PreTrainedConfig,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None,
        position_ids: torch.Tensor | None,
        token_type_ids: torch.Tensor | None = None,
        is_first_iteration: bool | None = False,
        **kwargs,
    ) -> dict:
        mask_kwargs = {
            "config": config.get_text_config(),
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
        }

        if token_type_ids is not None:
            mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
                token_type_ids, device=inputs_embeds.device
            )

        # Inside a method body this name resolves to the module-level `create_masks_for_generate`,
        # not to this method, so the call is not recursive.
        return create_masks_for_generate(**mask_kwargs)
893
+
894
+
895
@auto_docstring(
    custom_intro="""
    Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
    It uses the generic sequence classification implementation for efficiency and consistency."""
)
class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
    # No behavior of its own: the class body only pins the config type and the accepted modalities.
    config: Gemma3TextConfig
    input_modalities = ("text",)
903
+
904
+
905
# Sequence-classification variant whose `forward` additionally names the
# multimodal inputs (`pixel_values`, `token_type_ids`) in its signature.
class Gemma3ForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        # Pure pass-through: every argument is forwarded by keyword to the generic implementation.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            labels=labels,
            **kwargs,
        )
929
+
930
+
931
# Public names exported by this module.
__all__ = [
    "Gemma3Config",
    "Gemma3TextConfig",
    "Gemma3PreTrainedModel",
    "Gemma3TextModel",
    "Gemma3ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3Model",
    "Gemma3ForSequenceClassification",
    "Gemma3TextForSequenceClassification",
]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/gemma3/processing_gemma3.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+
17
+ from ...feature_extraction_utils import BatchFeature
18
+ from ...image_utils import ImageInput, make_nested_list_of_images
19
+ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
20
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
21
+ from ...utils import auto_docstring, to_py_obj
22
+
23
+
24
# Keyword-argument schema for `Gemma3Processor.__call__`; `_defaults` holds
# the values used when the caller does not override them.
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        # Defaults applied to the tokenizer call.
        "text_kwargs": {
            "padding": False,
            # When true the processor adds `token_type_ids` to its output.
            "return_mm_token_type_ids": True,
        },
        # Defaults applied to the image-processor call.
        "images_kwargs": {
            "do_convert_rgb": True,
            # Pan-and-scan cropping is off by default; the three values below are its parameters.
            "do_pan_and_scan": False,
            "pan_and_scan_min_crop_size": 256,
            "pan_and_scan_max_num_crops": 4,
            "pan_and_scan_min_ratio_to_activate": 1.2,
        },
    }
38
+
39
+
40
+ @auto_docstring
41
# Combines an image processor and a tokenizer: expands every begin-of-image
# token in the prompt into the full image placeholder sequence, tokenizes the
# result and merges it with the processed pixel inputs.
class Gemma3Processor(ProcessorMixin):
    def __init__(
        self,
        image_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs,
    ):
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        # One image is represented as: "\n\n" + boi + `image_seq_length` image tokens + eoi + "\n\n".
        image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        **kwargs: Unpack[Gemma3ProcessorKwargs],
    ) -> BatchFeature:
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # `text` may be None for image-only calls (placeholder prompts are synthesized below), so the
        # type check is skipped in that case instead of failing on `text[0]` with an opaque TypeError.
        if isinstance(text, str):
            text = [text]
        elif text is not None and not isinstance(text, list) and not isinstance(text[0], str):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")

        image_inputs = {}
        if images is not None:
            images = self.image_processor.fetch_images(images)
            batched_images = make_nested_list_of_images(images)
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])

            # Create empty text to be replaced with placeholders
            if not text:
                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]

            if len(batched_images) != len(text):
                raise ValueError(
                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
                )

            # Prompts are rewritten by index below; work on a copy so the caller's list is left untouched.
            text = list(text)
            # Match the token literally, whatever characters it contains; compiled once for the whole batch.
            boi_pattern = re.compile(re.escape(self.boi_token))

            # Replace image tokens by the full expanded sequence
            num_crops = to_py_obj(image_inputs.pop("num_crops"))
            # Regroup the flat per-image crop counts so they line up with the per-prompt image lists.
            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
            for batch_idx, (prompt, prompt_images, prompt_crops) in enumerate(
                zip(text, batched_images, batch_num_crops)
            ):
                image_indexes = [m.start() for m in boi_pattern.finditer(prompt)]

                if len(prompt_images) != len(image_indexes):
                    raise ValueError(
                        f"Prompt contained {len(image_indexes)} image tokens but received {len(prompt_images)} images."
                    )

                # Insert additional image tokens for Pan-and-Scan crops
                # (iterating right-to-left keeps the earlier match offsets valid while the prompt grows)
                for num, idx in reversed(list(zip(prompt_crops, image_indexes))):
                    if num:
                        formatted_image_text = (
                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                            + " ".join([self.boi_token] * num)
                        )
                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
                        text[batch_idx] = prompt

            # Expand placeholder image tokens to the full image token sequence
            text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]

        # These two keys are consumed here and must not reach the tokenizer call.
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image"])

        if return_mm_token_type_ids:
            text_inputs["token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """

        vision_data = {}
        if image_sizes is not None:
            # NOTE: no image cropping supported yet
            # Every image therefore counts as `image_seq_length` tokens and one patch, whatever its size.
            num_image_tokens = [self.image_seq_length] * len(image_sizes)
            num_image_patches = [1] * len(image_sizes)

            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        return MultiModalData(**vision_data)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_input_names = self.image_processor.model_input_names

        # `num_crops` is popped inside `__call__` and never returned, so it is not advertised as a model input.
        image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
        return list(tokenizer_input_names + image_processor_input_names)
163
+
164
+
165
+ __all__ = ["Gemma3Processor"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Only static type checkers take this branch and see the real symbols.
    from .configuration_youtu import *
    from .modeling_youtu import *
else:
    import sys

    # At runtime this module's entry in `sys.modules` is replaced by a `_LazyModule` built from the
    # import structure of this file (presumably deferring the submodule imports until first use).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/configuration_youtu.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/youtu/modular_youtu.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_youtu.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+
26
+
27
+ from huggingface_hub.dataclasses import strict
28
+
29
+ from ...configuration_utils import PreTrainedConfig
30
+ from ...modeling_rope_utils import RopeParameters
31
+ from ...utils import auto_docstring
32
+
33
+
34
@auto_docstring(checkpoint="tencent/Youtu-LLM-2B")
@strict
class YoutuConfig(PreTrainedConfig):
    r"""
    rope_interleave (`bool`, *optional*, defaults to `True`):
        Whether to interleave the rotary position embeddings.
    embedding_initializer_range (`float`, *optional*):
        The standard deviation of the truncated_normal_initializer for initializing all embedding matrices.

    ```python
    >>> from transformers import YoutuModel, YoutuConfig
    >>> # Initializing a Youtu-LLM-2B style configuration
    >>> configuration = YoutuConfig()
    >>> # Initializing a model from the configuration
    >>> model = YoutuModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "youtu"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel sharding style per module-name pattern (MLP projections only).
    base_model_tp_plan = {
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    attribute_map = {}

    vocab_size: int = 128256
    hidden_size: int = 2048
    intermediate_size: int = 6144
    num_hidden_layers: int = 32
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    kv_lora_rank: int = 512
    q_lora_rank: int | None = 1536
    qk_rope_head_dim: int = 64
    v_head_dim: int | None = 128
    qk_nope_head_dim: int = 128
    hidden_act: str = "silu"
    max_position_embeddings: int = 131072
    # None means "derive from hidden_size" (see `__post_init__`).
    initializer_range: float | None = None
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = None
    bos_token_id: int | None = 128000
    eos_token_id: int | list[int] | None = 128001
    tie_word_embeddings: bool = True
    rope_parameters: RopeParameters | dict | None = None
    rope_interleave: bool | None = True
    attention_bias: bool = False
    attention_dropout: float | int | None = 0.0
    # None (or any falsy value) means "use 2x initializer_range" (see `__post_init__`).
    embedding_initializer_range: float | None = None

    def __post_init__(self, **kwargs):
        # Fill in derived defaults before handing over to the parent hook.
        if self.initializer_range is None:
            if self.hidden_size != 0:
                # 2 / sqrt(5 * hidden_size)
                self.initializer_range = 2.0 / (5.0 * self.hidden_size) ** 0.5
            else:
                # Fallback that avoids dividing by zero.
                self.initializer_range = 0.02

        # `or` also replaces an explicit 0.0, not only None.
        self.embedding_initializer_range = self.embedding_initializer_range or 2.0 * self.initializer_range
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        # Query/key head width is the sum of the non-rotary and rotary parts;
        # `head_dim` is set to the rotary part only.
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.head_dim = self.qk_rope_head_dim
        super().__post_init__(**kwargs)
105
+
106
+
107
+ __all__ = ["YoutuConfig"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modeling_youtu.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/youtu/modular_youtu.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_youtu.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+
26
+
27
+ import math
28
+ from collections.abc import Callable
29
+ from typing import Optional
30
+
31
+ import torch
32
+ import torch.nn.functional as F
33
+ from torch import nn
34
+
35
+ from ... import initialization as init
36
+ from ...activations import ACT2FN
37
+ from ...cache_utils import Cache, DynamicCache
38
+ from ...generation import GenerationMixin
39
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
40
+ from ...masking_utils import create_causal_mask
41
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
42
+ from ...modeling_layers import GradientCheckpointingLayer
43
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
44
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
+ from ...processing_utils import Unpack
47
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
48
+ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
49
+ from ...utils.output_capturing import capture_outputs
50
+ from .configuration_youtu import YoutuConfig
51
+
52
+
53
@use_kernel_forward_from_hub("RMSNorm")
class YoutuRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        YoutuRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        # One multiplicative gain per feature, initialised to 1.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise `hidden_states` by its RMS (computed in float32) and apply the gain."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the caller's dtype before applying the learned gain.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
72
+
73
+
74
class YoutuRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used for rotary position embeddings."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: YoutuConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        # "default" uses the local static method; any other type is looked up in the registry.
        self.rope_type = self.config.rope_parameters["rope_type"]
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are non-persistent, i.e. they are not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: YoutuConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_base = config.rope_parameters["rope_theta"]
        # Prefer an explicit `head_dim`; otherwise derive it from the hidden size and head count.
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Exponents 0, 2/dim, 4/dim, ... built from integer indices and then cast to float.
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (rope_base**exponents)

        # The scaling factor is fixed at 1.0 for this RoPE type.
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`."""
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and any non-string device type) falls back to "cpu" for the autocast context.
        if isinstance(x.device.type, str) and x.device.type != "mps":
            device_type = x.device.type
        else:
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so the table spans the full rotary dimension.
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
137
+
138
+
139
class YoutuMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        # Activation is selected by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and project back to the hidden size."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
153
+
154
+
155
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the negated
    second part followed by the first part.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
160
+
161
+
162
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            With `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q` and `k` are
            [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard rotary combination: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
186
+
187
+
188
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
        expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states
198
+
199
+
200
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul/softmax attention.

    Keys and values are first repeated `module.num_key_value_groups` times along the head axis.
    `attention_mask`, when given, is added to the scaled scores. Extra `kwargs` are accepted but not used.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has its dimensions 1 and 2 swapped relative to the
        matmul result and is made contiguous.
    """
    n_groups = module.num_key_value_groups
    key_states = repeat_kv(key, n_groups)
    value_states = repeat_kv(value, n_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax runs in float32 and is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
223
+
224
+
225
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO let's just use the original freqs_cis computation to not have the view
    transpose + reshape! This is not optimized!
    Applies Rotary Position Embedding to the query and key tensors.

    Before rotating, the last dimension of `q` and `k` is regrouped from adjacent pairs
    (x0, x1, x2, x3, ...) into the even-indexed entries followed by the odd-indexed entries.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Accepted for signature compatibility; this implementation does not read it.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            With `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q` and `k` are
            [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _deinterleave(states):
        # (…, d) -> (…, d/2, 2) -> (…, 2, d/2) -> (…, d): evens first, then odds.
        b, h, s, d = states.shape
        return states.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)

    q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
    k_embed = (k * cos_b) + (rotate_half(k) * sin_b)
    return q_embed, k_embed
261
+
262
+
263
def yarn_get_mscale(scale=1, mscale=1):
    """Return the attention magnitude correction for a given RoPE scaling factor.

    Yields `1.0` when `scale <= 1`, otherwise `0.1 * mscale * ln(scale) + 1.0`.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
267
+
268
+
269
class YoutuAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries are produced either by a single projection (`q_lora_rank is None`) or by a
    down-projection / RMSNorm / up-projection chain. Keys and values are produced from a
    compressed `kv_lora_rank`-wide latent plus a separate `qk_rope_head_dim`-wide rotary key
    that is computed once and broadcast to every head.
    """

    def __init__(self, config: YoutuConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        if self.q_lora_rank is None:
            # Direct query projection.
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            # Low-rank query path: down-project, normalise, up-project.
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = YoutuRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Emits the compressed KV latent and the shared rotary key in one projection.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = YoutuRMSNorm(self.kv_lora_rank)
        # Expands the latent into per-head non-rotary keys and values.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)
        if self.config.rope_parameters.get("rope_type", "default") != "default":
            mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
            # NOTE(review): "factor" is read unconditionally for every non-default rope type,
            # so such configs must provide it — confirm against the supported rope types.
            scaling_factor = self.config.rope_parameters["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input whose leading two dimensions are read as (batch, sequence).
            position_embeddings: `(cos, sin)` pair applied to the rotary part of queries and keys.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_values: Optional cache; when given, it is updated with this layer's keys/values
                and the returned (updated) states are used for attention.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)` — the output after `o_proj`, and whatever weights the
            selected attention backend returned.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(query_shape).transpose(1, 2)
        # Split each query head into its non-rotary ("pass") and rotary ("rot") parts.
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Expand the normalised latent into per-head non-rotary keys and values.
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key has a single head at this point.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # Broadcast the single rotary key head across all heads so it can be concatenated with k_pass.
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # For flash attention, values are zero-padded up to the query/key head dim when the two differ;
        # the padding is sliced off again after the attention call.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
385
+
386
+
387
class YoutuDecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: pre-norm self-attention and pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self, config: YoutuConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = YoutuAttention(config=config, layer_idx=layer_idx)

        self.mlp = YoutuMLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks to `hidden_states` and return the result."""
        # Self Attention (the attention weights returned alongside the output are discarded)
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Fully Connected
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
428
+
429
+
430
@auto_docstring
class YoutuPreTrainedModel(PreTrainedModel):
    # Shared base for the Youtu models: declares capability flags and the embedding initialisation.
    config: YoutuConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YoutuDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": YoutuDecoderLayer,
        "attentions": YoutuAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the parent initialisation, then re-initialise `nn.Embedding` weights with the embedding std."""
        super()._init_weights(module)
        if not isinstance(module, nn.Embedding):
            return

        # The embedding std falls back to twice the general initializer range.
        base_std = getattr(self.config, "initializer_range", 0.02)
        embedding_std = getattr(self.config, "embedding_initializer_range", 2 * base_std)
        init.normal_(module.weight, mean=0.0, std=embedding_std)
        if module.padding_idx is not None:
            # Keep the padding row at zero.
            init.zeros_(module.weight.data[module.padding_idx])
457
+
458
+
459
@auto_docstring
class YoutuModel(YoutuPreTrainedModel):
    # Decoder stack: token embedding -> `num_hidden_layers` decoder layers -> final RMSNorm.
    def __init__(self, config: YoutuConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            YoutuDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = YoutuRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Create a fresh cache when caching is requested but none was passed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Default positions continue from the number of tokens already held in the cache.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = new_positions.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # The (cos, sin) pair is computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
531
+
532
+
533
@auto_docstring
class YoutuForCausalLM(YoutuPreTrainedModel, GenerationMixin):
    # `lm_head.weight` is declared as tied to the input embedding weight.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = YoutuModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, YoutuForCausalLM

        >>> model = YoutuForCausalLM.from_pretrained("tencent/Youtu-LLM-2B")
        >>> tokenizer = AutoTokenizer.from_pretrained("tencent/Youtu-LLM-2B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 yields slice(0, None), i.e. all of them);
        # any other value is used directly as the index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
605
+
606
+
607
+ __all__ = ["YoutuPreTrainedModel", "YoutuModel", "YoutuForCausalLM"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/youtu/modular_youtu.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+
21
+ import torch
22
+ from huggingface_hub.dataclasses import strict
23
+ from torch import nn
24
+
25
+ from ... import initialization as init
26
+ from ...modeling_utils import PreTrainedModel
27
+ from ...utils import auto_docstring, logging
28
+ from ..deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
29
+ from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3Attention
30
+ from ..llama.modeling_llama import (
31
+ LlamaDecoderLayer,
32
+ LlamaForCausalLM,
33
+ LlamaModel,
34
+ LlamaPreTrainedModel,
35
+ LlamaRMSNorm,
36
+ LlamaRotaryEmbedding,
37
+ )
38
+ from ..qwen3.modeling_qwen3 import Qwen3MLP
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
@auto_docstring(checkpoint="tencent/Youtu-LLM-2B")
@strict
class YoutuConfig(DeepseekV3Config):
    r"""
    rope_interleave (`bool`, *optional*, defaults to `True`):
        Whether to interleave the rotary position embeddings.
    embedding_initializer_range (`float`, *optional*):
        The standard deviation of the truncated_normal_initializer for initializing all embedding matrices.

    ```python
    >>> from transformers import YoutuModel, YoutuConfig
    >>> # Initializing a Youtu-LLM-2B style configuration
    >>> configuration = YoutuConfig()
    >>> # Initializing a model from the Youtu-LLM-2B style configuration
    >>> model = YoutuModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "youtu"
    # Tensor-parallel plan covering only the MLP projections.
    base_model_tp_plan = {
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    attribute_map = {}

    vocab_size: int = 128256
    hidden_size: int = 2048
    intermediate_size: int = 6144
    num_hidden_layers: int = 32
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    max_position_embeddings: int = 131072
    # Both initializer ranges default to None and are resolved in `__post_init__`.
    initializer_range: float | None = None
    embedding_initializer_range: float | None = None
    pad_token_id: int | None = None
    bos_token_id: int | None = 128000
    eos_token_id: int | list[int] | None = 128001
    tie_word_embeddings: bool = True

    # remove unused attribute
    n_shared_experts = AttributeError()
    n_routed_experts = AttributeError()
    routed_scaling_factor = AttributeError()
    n_group = AttributeError()
    topk_group = AttributeError()
    num_experts_per_tok = AttributeError()
    first_k_dense_replace = AttributeError()
    norm_topk_prob = AttributeError()
    pretraining_tp = AttributeError()
    moe_intermediate_size = AttributeError()

    def __post_init__(self, **kwargs):
        """Fill in the initializer ranges that were left unset, then defer to the parent hook."""
        if self.initializer_range is None:
            if self.hidden_size != 0:
                # Width-dependent default: 2 / sqrt(5 * hidden_size).
                self.initializer_range = 2.0 / (5.0 * self.hidden_size) ** 0.5
            else:
                self.initializer_range = 0.02

        # Any falsy value (None or 0) falls back to twice the general initializer range.
        self.embedding_initializer_range = self.embedding_initializer_range or 2.0 * self.initializer_range
        super().__post_init__(**kwargs)
104
+
105
+
106
# Youtu's RMSNorm: inherits `LlamaRMSNorm` with no overrides.
class YoutuRMSNorm(LlamaRMSNorm):
    pass
108
+
109
+
110
# Youtu's rotary embedding: inherits `LlamaRotaryEmbedding` with no overrides.
class YoutuRotaryEmbedding(LlamaRotaryEmbedding):
    pass
112
+
113
+
114
# Youtu's feed-forward block: inherits `Qwen3MLP` with no overrides.
class YoutuMLP(Qwen3MLP):
    pass
116
+
117
+
118
# Youtu's attention: inherits `DeepseekV3Attention` with no overrides.
class YoutuAttention(DeepseekV3Attention):
    pass
120
+
121
+
122
# Youtu's decoder layer: inherits `LlamaDecoderLayer` with no overrides.
class YoutuDecoderLayer(LlamaDecoderLayer):
    pass
124
+
125
+
126
class YoutuPreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
    @torch.no_grad()
    def _init_weights(self, module):
        """Run `PreTrainedModel._init_weights`, then re-initialise `nn.Embedding` weights with the embedding std."""
        # The base-class implementation is called explicitly rather than through `super()`.
        PreTrainedModel._init_weights(self, module)
        if not isinstance(module, nn.Embedding):
            return

        # The embedding std falls back to twice the general initializer range.
        base_std = getattr(self.config, "initializer_range", 0.02)
        embedding_std = getattr(self.config, "embedding_initializer_range", 2 * base_std)
        init.normal_(module.weight, mean=0.0, std=embedding_std)
        if module.padding_idx is not None:
            # Keep the padding row at zero.
            init.zeros_(module.weight.data[module.padding_idx])
136
+
137
+
138
# Youtu's base model: inherits `LlamaModel` with no overrides.
class YoutuModel(LlamaModel):
    pass
140
+
141
+
142
# Youtu's causal-LM head model: inherits `LlamaForCausalLM` with no overrides.
class YoutuForCausalLM(LlamaForCausalLM):
    pass
144
+
145
+
146
+ __all__ = [
147
+ "YoutuConfig",
148
+ "YoutuPreTrainedModel",
149
+ "YoutuModel",
150
+ "YoutuForCausalLM",
151
+ ]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/testing_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/training_args.py ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_070000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f9d782f9b8c989c4bdaccc104827d9b426cdf6c42c2bd671b6e40541ceb24c2
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_078000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68bdb7c05a6e6a3e90081c66c23125a3951666f75f5171b90f4ca8e23da89f69
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_079000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c31ea7165f98543157bb82e8ff15f8183ee06eaa5cfa6901c6ed884fbe5e0ec
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_343000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a4d7035ba11f7286021fe81bbe41d62ea19e278f5466a9c6506e495e025e2e6
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_352000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cf82b122d41a279824fef01b76fe887328610fe616cccbda5186d6bed45dd29
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_390000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84de046c4742e1f4294ba4ee53edbadd466c91ff6003748fbc7e2bcfc56e1d32
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_433000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71cd93a38ffcecccd3b1e054151774919038ec40a6aec162ebcfa5db27fa7b46
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_471000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b1caa92ac5cea91c492b668a52dbce1d34b2cf72db6edbd2ee768d2608cf203
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_565000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ebd372ab29ab0dde7b3aab27de2e20487c5fce1628a03df9d46b59795702246
3
+ size 897562466
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr3e3_ema0p9999_elfopt_not5_bottleneck16_unfixed_norm_stateprobadd_selfcond_ce_fast_20260612_030202/step_571000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:477f5f24e284fcf6630d7a64ed678211de0b09b4ab11f149bc02dd43c78678da
3
+ size 897562466