StarMist0012 commited on
Commit
e2bfccc
·
verified ·
1 Parent(s): 388fd6e

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +129 -95
  2. artifacts/configs/pretrain.yaml +86 -0
  3. artifacts/configs/sft.yaml +92 -0
  4. artifacts/diagnostics/activation_probe_pretrain_final.json +118 -0
  5. artifacts/diagnostics/generation_samples_pretrain_final.json +23 -0
  6. artifacts/diagnostics/generation_samples_sft_final.json +23 -0
  7. artifacts/gpu_telemetry_nvidia_smi.csv +0 -0
  8. artifacts/loss_curve_summary.json +26 -0
  9. artifacts/model_card.json +13 -0
  10. artifacts/run.sh +55 -0
  11. artifacts/run_plan.json +19 -0
  12. artifacts/status.json +1 -0
  13. code/TaoTrain/.gitignore +150 -0
  14. code/TaoTrain/README.md +423 -0
  15. code/TaoTrain/configs/pretrain.yaml +138 -0
  16. code/TaoTrain/configs/rl_dpo.yaml +60 -0
  17. code/TaoTrain/configs/sft.yaml +93 -0
  18. code/TaoTrain/configs/ssm_pretrain.yaml +87 -0
  19. code/TaoTrain/configs/tokenizer.yaml +44 -0
  20. code/TaoTrain/configs/tokenizer_taodata_pilot.yaml +22 -0
  21. code/TaoTrain/configs/yarn_pretrain.yaml +181 -0
  22. code/TaoTrain/data/sample_data.jsonl +10 -0
  23. code/TaoTrain/data/sample_sft_multi.jsonl +4 -0
  24. code/TaoTrain/data/sample_sft_single.jsonl +5 -0
  25. code/TaoTrain/pyproject.toml +65 -0
  26. code/TaoTrain/scripts/benchmark_taonet_real_tokens.py +1034 -0
  27. code/TaoTrain/scripts/benchmark_taonet_token_variants.py +516 -0
  28. code/TaoTrain/scripts/diagnostics/activation_probe.py +160 -0
  29. code/TaoTrain/scripts/diagnostics/generate_checkpoint_samples.py +134 -0
  30. code/TaoTrain/scripts/diagnostics/sft_sanity_check.py +300 -0
  31. code/TaoTrain/scripts/profile_taonet_components.py +393 -0
  32. code/TaoTrain/scripts/remote/job_status.sh +36 -0
  33. code/TaoTrain/scripts/remote/run_200m_base_suite.sh +149 -0
  34. code/TaoTrain/scripts/remote/run_200m_branch_only_chat.sh +329 -0
  35. code/TaoTrain/scripts/remote/run_200m_until_selection.sh +81 -0
  36. code/TaoTrain/scripts/remote/run_pre_200m_branch_only_gate.sh +269 -0
  37. code/TaoTrain/scripts/remote/run_pre_200m_stability_gate.sh +267 -0
  38. code/TaoTrain/scripts/remote/submit_200m_base_suite.sh +16 -0
  39. code/TaoTrain/scripts/remote/submit_200m_branch_only_chat.sh +32 -0
  40. code/TaoTrain/scripts/remote/submit_200m_until_selection.sh +28 -0
  41. code/TaoTrain/scripts/remote/submit_detached_job.sh +118 -0
  42. code/TaoTrain/scripts/remote/submit_pre_200m_branch_only_gate.sh +33 -0
  43. code/TaoTrain/scripts/remote/submit_pre_200m_stability_gate.sh +32 -0
  44. code/TaoTrain/scripts/remote/submit_ssm_improvement_sweep.sh +71 -0
  45. code/TaoTrain/scripts/summarize_taonet_benchmark_suite.py +138 -0
  46. code/TaoTrain/src/taoTrain.egg-info/dependency_links.txt +1 -0
  47. code/TaoTrain/src/taoTrain.egg-info/entry_points.txt +4 -0
  48. code/TaoTrain/src/taoTrain/__init__.py +8 -0
  49. code/TaoTrain/src/taoTrain/cli.py +387 -0
  50. code/TaoTrain/src/taoTrain/config.py +726 -0
README.md CHANGED
@@ -2,7 +2,7 @@
2
  license: mit
3
  language:
4
  - en
5
- library_name: pytorch
6
  pipeline_tag: text-generation
7
  tags:
8
  - taonet
@@ -11,7 +11,8 @@ tags:
11
  - state-space-model
12
  - dplr
13
  - pytorch
14
- - custom-code
 
15
  - text-generation
16
  - experimental
17
  datasets:
@@ -20,15 +21,93 @@ datasets:
20
 
21
  # TaoNet-mini-T2
22
 
23
- TaoNet-mini-T2 is an experimental 196M-parameter TaoNet-style language model that replaces the attention sequence mixer with a Taotern/Gamma DPLR state-space model (SSM). It is packaged as a runnable TaoTrain checkpoint bundle rather than a native `transformers.AutoModelForCausalLM` repository.
24
 
25
- This release is intended for research backup, reproducibility, and deployment investigation. It includes the model checkpoints, tokenizer, TaoTrain runtime code, Taotern SSM code, run configs, diagnostics, and Windows/Linux helper scripts.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  ## Model Details
28
 
29
  | Field | Value |
30
- |---|---|
31
- | Model family | TaoNet / Taotern SSM |
32
  | Architecture | `taonet_ssm` |
33
  | Candidate | `pure_ssm_196m_branch_rms_only` |
34
  | Parameters | 196,573,128 |
@@ -45,20 +124,17 @@ This release is intended for research backup, reproducibility, and deployment in
45
  | Local shift | Enabled, per-channel |
46
  | Branch RMS norm | Enabled |
47
 
48
- ## Files
49
-
50
- Download and unzip:
51
-
52
- ```text
53
- taotern-200m-branch-only-chat-20260514_handoff.zip
54
- ```
55
-
56
- The zip contains:
57
 
58
  ```text
 
 
 
 
 
59
  model/
60
  final_model.pt # SFT final checkpoint
61
- pretrain_final_model.pt # recommended checkpoint for fixed chat test
62
  tokenizer/
63
  tokenizer.model
64
  tokenizer.vocab
@@ -69,100 +145,64 @@ code/
69
  artifacts/
70
  configs/
71
  diagnostics/
72
- model_card.json
73
- run_plan.json
74
- loss_curve_summary.json
75
- chat_ssm_fixed.py
76
- RUN_TEST_WINDOWS.cmd
77
- setup_windows.ps1
78
- setup_linux.sh
79
- run_chat_fixed.ps1
80
- run_chat_fixed.sh
81
- CHECKSUMS.sha256
82
- README_HANDOFF.md
83
  ```
84
 
85
- Package checksum:
86
-
87
- ```text
88
- sha256=5dd45e5d236f37101d183cadb92c5a13ea0aa574a24f578faa6cc8ee7ed166d6
89
- ```
90
-
91
- ## Quick Start
92
 
93
- ### Windows
94
 
95
- 1. Download and unzip `taotern-200m-branch-only-chat-20260514_handoff.zip`.
96
- 2. Double-click:
97
-
98
- ```text
99
- RUN_TEST_WINDOWS.cmd
100
  ```
101
 
102
- The launcher creates or reuses `.venv`, tries PyTorch CUDA 12.8 first, falls back to CUDA 12.6, installs the packaged TaoTrain/Taotern SSM code, then launches the fixed chat script.
103
-
104
- Manual Windows setup:
105
 
106
  ```powershell
107
- powershell -ExecutionPolicy Bypass -File .\setup_windows.ps1 -TorchFlavor cu128
108
- powershell -ExecutionPolicy Bypass -File .\run_chat_fixed.ps1
109
  ```
110
 
111
- If CUDA 12.8 fails:
112
 
113
- ```powershell
114
- powershell -ExecutionPolicy Bypass -File .\setup_windows.ps1 -TorchFlavor cu126
115
- ```
116
 
117
- CPU-only setup is possible but very slow:
118
 
119
- ```powershell
120
- powershell -ExecutionPolicy Bypass -File .\setup_windows.ps1 -TorchFlavor cpu
 
 
121
  ```
122
 
123
- ### Linux
124
 
125
- ```bash
126
- unzip taotern-200m-branch-only-chat-20260514_handoff.zip
127
- cd taotern-200m-branch-only-chat-20260514
128
- chmod +x setup_linux.sh run_chat_fixed.sh
129
- ./setup_linux.sh cu128
130
- ./run_chat_fixed.sh
131
- ```
132
 
133
- If CUDA 12.8 is not suitable, try:
134
-
135
- ```bash
136
- ./setup_linux.sh cu126
137
- ```
138
-
139
- ## Recommended Inference Mode
140
-
141
- The checkpoint was trained with convolutional SSM execution. For chat testing, this bundle uses a temporary RepoBridge-style inference fix:
142
 
143
  ```text
 
 
 
 
 
144
  ssm_finite_tail_correction=true
145
- ssm_kernel_mode=recurrent
146
- default_checkpoint=model/pretrain_final_model.pt
147
- decode=greedy
148
- top_p=0.85
149
- repetition_penalty=1.2
150
- max_new_tokens=64
151
- ```
152
-
153
- Run it directly:
154
-
155
- ```bash
156
- python chat_ssm_fixed.py --interactive
157
  ```
158
 
159
- or on Windows:
160
 
161
- ```powershell
162
- powershell -ExecutionPolicy Bypass -File .\run_chat_fixed.ps1
163
- ```
 
 
 
 
 
164
 
165
- Plain TaoTrain `tui-chat` reconstructs the model from the training config (`ssm_finite_tail_correction=false`, `ssm_kernel_mode=conv`) and may produce much worse samples. Use `chat_ssm_fixed.py` for the expected behavior.
166
 
167
  ## Training Summary
168
 
@@ -172,8 +212,6 @@ Run ID:
172
  taotern-200m-branch-only-chat-20260514
173
  ```
174
 
175
- Training shape:
176
-
177
  | Stage | Value |
178
  |---|---:|
179
  | Pretrain token positions | 4,000,000,000 |
@@ -204,17 +242,14 @@ This model is intended for:
204
 
205
  - Taotern/TaoNet SSM research
206
  - checkpoint backup and reproducibility
207
- - deployment experiments for a custom TaoTrain runtime
208
  - studying recurrent SSM inference behavior
209
 
210
- It is not currently a drop-in Transformers model and is not intended as a polished production chatbot.
211
-
212
  ## Limitations
213
 
214
- - Experimental model quality; validate before use.
215
- - Requires custom TaoTrain and Taotern SSM code included in the package.
216
  - The recommended chat path depends on an inference-time SSM override.
217
- - Not currently packaged as `AutoModelForCausalLM`.
218
  - CPU inference is expected to be very slow.
219
  - English-focused pilot data/tokenizer.
220
 
@@ -232,4 +267,3 @@ It is not currently a drop-in Transformers model and is not intended as a polish
232
  ## Related
233
 
234
  - [TaoTern/TaoNet-pico-T1](https://huggingface.co/TaoTern/TaoNet-pico-T1)
235
-
 
2
  license: mit
3
  language:
4
  - en
5
+ library_name: transformers
6
  pipeline_tag: text-generation
7
  tags:
8
  - taonet
 
11
  - state-space-model
12
  - dplr
13
  - pytorch
14
+ - transformers
15
+ - custom_code
16
  - text-generation
17
  - experimental
18
  datasets:
 
21
 
22
  # TaoNet-mini-T2
23
 
24
+ TaoNet-mini-T2 is an experimental 196M-parameter TaoNet language model using a Taotern/Gamma DPLR state-space model (SSM) sequence core instead of attention. The repository includes the full training handoff package, but the recommended inference path is now Hugging Face `transformers` remote code:
25
 
26
+ ```python
27
+ AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True)
28
+ ```
29
+
30
+ The default `transformers` loader downloads `model/pretrain_final_model.pt` and applies the RepoBridge chat-quality fix: `ssm_finite_tail_correction=True` and `ssm_kernel_mode="recurrent"`.
31
+
32
+ ## Quick Start
33
+
34
+ Install runtime dependencies:
35
+
36
+ ```bash
37
+ pip install torch transformers sentencepiece huggingface_hub pydantic pydantic-settings pyyaml numpy
38
+ ```
39
+
40
+ For the private review repo, log in first:
41
+
42
+ ```bash
43
+ hf auth login
44
+ ```
45
+
46
+ Run generation from Python:
47
+
48
+ ```python
49
+ import time
50
+ import torch
51
+ from transformers import AutoModelForCausalLM, AutoTokenizer
52
+
53
+ MODEL_NAME = "TaoTern/TaoNet-mini-T2"
54
+
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
57
+
58
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
59
+ model = AutoModelForCausalLM.from_pretrained(
60
+ MODEL_NAME,
61
+ trust_remote_code=True,
62
+ torch_dtype=dtype,
63
+ ).to(device)
64
+
65
+
66
+ def generate_text(prompt, max_new_tokens=64, temperature=0.7, top_p=0.85):
67
+ inputs = tokenizer(prompt, return_tensors="pt")
68
+ inputs = {key: value.to(device) for key, value in inputs.items()}
69
+
70
+ start_time = time.time()
71
+ with torch.inference_mode():
72
+ outputs = model.generate(
73
+ **inputs,
74
+ max_new_tokens=max_new_tokens,
75
+ temperature=temperature,
76
+ top_p=top_p,
77
+ repetition_penalty=1.2,
78
+ do_sample=True,
79
+ use_cache=False,
80
+ pad_token_id=tokenizer.pad_token_id,
81
+ eos_token_id=tokenizer.eos_token_id,
82
+ )
83
+ elapsed_time = time.time() - start_time
84
+
85
+ new_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
86
+ tokens_per_second = new_tokens / elapsed_time if elapsed_time > 0 else 0.0
87
+ completion = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
88
+ return completion, tokens_per_second
89
+
90
+
91
+ if __name__ == "__main__":
92
+ text, tps = generate_text("Fruit is now expensive so we should")
93
+ print(text)
94
+ print(f"\nTokens per second: {tps:.2f}")
95
+ ```
96
+
97
+ To load the SFT final checkpoint instead of the default pretrain checkpoint:
98
+
99
+ ```python
100
+ model = AutoModelForCausalLM.from_pretrained(
101
+ "TaoTern/TaoNet-mini-T2",
102
+ trust_remote_code=True,
103
+ checkpoint_name="final_model.pt",
104
+ )
105
+ ```
106
 
107
  ## Model Details
108
 
109
  | Field | Value |
110
+ |---|---:|
 
111
  | Architecture | `taonet_ssm` |
112
  | Candidate | `pure_ssm_196m_branch_rms_only` |
113
  | Parameters | 196,573,128 |
 
124
  | Local shift | Enabled, per-channel |
125
  | Branch RMS norm | Enabled |
126
 
127
+ ## Repository Layout
 
 
 
 
 
 
 
 
128
 
129
  ```text
130
+ config.json
131
+ configuration_taonet_mini_t2.py
132
+ modeling_taonet_mini_t2.py
133
+ tokenization_taonet_mini_t2.py
134
+ tokenizer.model
135
  model/
136
  final_model.pt # SFT final checkpoint
137
+ pretrain_final_model.pt # default checkpoint for HF inference
138
  tokenizer/
139
  tokenizer.model
140
  tokenizer.vocab
 
145
  artifacts/
146
  configs/
147
  diagnostics/
148
+ chat_ssm_fixed.py # legacy local fixed-chat CLI
149
+ eval_lm_eval.py # local lm-eval harness wrapper
 
 
 
 
 
 
 
 
 
150
  ```
151
 
152
+ ## Upload Notes
 
 
 
 
 
 
153
 
154
+ This repo contains two multi-GB checkpoint files, so prefer the resumable large-folder uploader instead of the normal single-commit upload command:
155
 
156
+ ```bash
157
+ hf upload-large-folder TaoTern/TaoNet-mini-T2 . --repo-type model --private
 
 
 
158
  ```
159
 
160
+ On Windows, from the repo folder:
 
 
161
 
162
  ```powershell
163
+ powershell -ExecutionPolicy Bypass -File .\upload_large_folder.ps1
 
164
  ```
165
 
166
+ ## Inference Notes
167
 
168
+ The training config used `ssm_finite_tail_correction=False` and `ssm_kernel_mode="conv"`. That path is fast for full-sequence training/evaluation but produced poor chat samples in the recovered workflow.
 
 
169
 
170
+ The `transformers` wrapper defaults to:
171
 
172
+ ```text
173
+ ssm_finite_tail_correction=True
174
+ ssm_kernel_mode=recurrent
175
+ checkpoint=model/pretrain_final_model.pt
176
  ```
177
 
178
+ For fast benchmark scoring, use the included `eval_lm_eval.py` script with `--ssm-kernel-mode conv --finite-tail`.
179
 
180
+ ## LM Evaluation Harness Benchmark
 
 
 
 
 
 
181
 
182
+ Settings:
 
 
 
 
 
 
 
 
183
 
184
  ```text
185
+ library=lm-eval-harness
186
+ checkpoint=model/pretrain_final_model.pt
187
+ num_fewshot=0
188
+ limit=100
189
+ ssm_kernel_mode=conv
190
  ssm_finite_tail_correction=true
191
+ eval_batch_size=8
 
 
 
 
 
 
 
 
 
 
 
192
  ```
193
 
194
+ Results:
195
 
196
+ | Task | Primary score |
197
+ |---|---:|
198
+ | HellaSwag | 0.3300 |
199
+ | ARC Easy | 0.3400 |
200
+ | ARC Challenge | 0.2200 |
201
+ | PIQA | 0.4400 |
202
+ | Winogrande | 0.5300 |
203
+ | Mean primary score | 0.3720 |
204
 
205
+ These are limit-100 smoke benchmark numbers for review, not full leaderboard results.
206
 
207
  ## Training Summary
208
 
 
212
  taotern-200m-branch-only-chat-20260514
213
  ```
214
 
 
 
215
  | Stage | Value |
216
  |---|---:|
217
  | Pretrain token positions | 4,000,000,000 |
 
242
 
243
  - Taotern/TaoNet SSM research
244
  - checkpoint backup and reproducibility
245
+ - deployment experiments with custom Hugging Face remote code
246
  - studying recurrent SSM inference behavior
247
 
 
 
248
  ## Limitations
249
 
250
+ - Experimental model quality; validate outputs before use.
251
+ - Requires `trust_remote_code=True` because the architecture is not part of upstream `transformers`.
252
  - The recommended chat path depends on an inference-time SSM override.
 
253
  - CPU inference is expected to be very slow.
254
  - English-focused pilot data/tokenizer.
255
 
 
267
  ## Related
268
 
269
  - [TaoTern/TaoNet-pico-T1](https://huggingface.co/TaoTern/TaoNet-pico-T1)
 
artifacts/configs/pretrain.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ architecture_type: taonet_ssm
3
+ vocab_size: 8192
4
+ hidden_dim: 1024
5
+ num_layers: 18
6
+ num_heads: 8
7
+ max_seq_length: 512
8
+ d_latent_kv: 768
9
+ d_rope: 128
10
+ hidden_dim_ff: 3072
11
+ dropout: 0.0
12
+ gqa_groups: 1
13
+ use_factorized_embedding: false
14
+ d_embed_rank: 96
15
+ init_std: 0.02
16
+ ssm_core: dplr
17
+ ssm_hidden_dim: 32
18
+ ssm_mixer_dim: 256
19
+ ssm_num_lanes: 2
20
+ ssm_lane_combine: channel
21
+ ssm_lane_mode: split
22
+ ssm_split_mix: none
23
+ ssm_rank: 1
24
+ ssm_max_low_rank_scale: 0.1
25
+ ssm_finite_tail_correction: false
26
+ ssm_discretization: bilinear
27
+ ssm_kernel_mode: conv
28
+ ssm_kernel_threshold: 64
29
+ ssm_dt_min: 1e-3
30
+ ssm_dt_max: 1e-1
31
+ ssm_dt_init: 1e-2
32
+ ssm_use_d: true
33
+ ssm_activation: gelu
34
+ ssm_gate: true
35
+ ssm_input_gate: true
36
+ ssm_gate_type: channel
37
+ ssm_use_padding_mask: false
38
+ ssm_layer_scale_init: 0.1
39
+ ssm_branch_rms_norm: true
40
+ block_residual_rms_norm: false
41
+
42
+ ssm_local_shift: true
43
+ ssm_local_shift_init: 0.1
44
+ ssm_local_shift_per_channel: true
45
+
46
+ dataset:
47
+ local: true
48
+ jsonl_path: /home/student/Data/TaoData/pretrain.jsonl
49
+ text_field: text
50
+ tokenizer_type: sentencepiece
51
+ tokenizer_path: /home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model
52
+ samples_per_chunk: 2000
53
+ tokenizer_threads: 8
54
+
55
+ sequence_length: 512
56
+ batch_size: 8
57
+ num_epochs: 100000
58
+ max_steps: 976563
59
+ gradient_accumulation_steps: 1
60
+ max_grad_norm: 1.0
61
+
62
+ optimizer:
63
+ optimizer_type: adamw
64
+ learning_rate: 0.0008
65
+ weight_decay: 0.01
66
+ betas: [0.9, 0.999]
67
+ eps: 1e-8
68
+
69
+ scheduler:
70
+ scheduler_type: constant
71
+ warmup_steps: 1000
72
+ warmup_ratio: 0.0
73
+
74
+ dtype: bfloat16
75
+ device: cuda
76
+ checkpoint_dir: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain
77
+ save_every_steps: 100000
78
+ save_best_model: false
79
+ keep_last_n_checkpoints: 3
80
+ eval_every_steps: 100000
81
+ eval_samples: 32
82
+ log_every_steps: 100
83
+ aim_repo: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/.aim-pretrain
84
+ seed: 43
85
+ num_workers: 0
86
+ pin_memory: true
artifacts/configs/sft.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ architecture_type: taonet_ssm
3
+ vocab_size: 8192
4
+ hidden_dim: 1024
5
+ num_layers: 18
6
+ num_heads: 8
7
+ max_seq_length: 512
8
+ d_latent_kv: 768
9
+ d_rope: 128
10
+ hidden_dim_ff: 3072
11
+ dropout: 0.0
12
+ gqa_groups: 1
13
+ use_factorized_embedding: false
14
+ d_embed_rank: 96
15
+ init_std: 0.02
16
+ ssm_core: dplr
17
+ ssm_hidden_dim: 32
18
+ ssm_mixer_dim: 256
19
+ ssm_num_lanes: 2
20
+ ssm_lane_combine: channel
21
+ ssm_lane_mode: split
22
+ ssm_split_mix: none
23
+ ssm_rank: 1
24
+ ssm_max_low_rank_scale: 0.1
25
+ ssm_finite_tail_correction: false
26
+ ssm_discretization: bilinear
27
+ ssm_kernel_mode: conv
28
+ ssm_kernel_threshold: 64
29
+ ssm_dt_min: 1e-3
30
+ ssm_dt_max: 1e-1
31
+ ssm_dt_init: 1e-2
32
+ ssm_use_d: true
33
+ ssm_activation: gelu
34
+ ssm_gate: true
35
+ ssm_input_gate: true
36
+ ssm_gate_type: channel
37
+ ssm_use_padding_mask: false
38
+ ssm_layer_scale_init: 0.1
39
+ ssm_branch_rms_norm: true
40
+ block_residual_rms_norm: false
41
+
42
+ ssm_local_shift: true
43
+ ssm_local_shift_init: 0.1
44
+ ssm_local_shift_per_channel: true
45
+
46
+ dataset:
47
+ split: train
48
+ instruction_column: input
49
+ response_column: output
50
+ local: true
51
+ jsonl_path: /home/student/Data/TaoData/sft.jsonl
52
+ samples_per_chunk: 2000
53
+ tokenizer_type: sentencepiece
54
+ tokenizer_path: /home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model
55
+ tokenizer_threads: 8
56
+
57
+ checkpoint_path: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt
58
+ user_token: "<user>"
59
+ assistant_token: "<assistant>"
60
+ response_loss_only: true
61
+
62
+ batch_size: 8
63
+ num_epochs: 100000
64
+ max_steps: 50000
65
+ gradient_accumulation_steps: 1
66
+ max_grad_norm: 1.0
67
+
68
+ optimizer:
69
+ optimizer_type: adamw
70
+ learning_rate: 0.00005
71
+ weight_decay: 0.0
72
+ betas: [0.9, 0.999]
73
+ eps: 1e-8
74
+
75
+ scheduler:
76
+ scheduler_type: constant
77
+ warmup_steps: 100
78
+ warmup_ratio: 0.0
79
+
80
+ dtype: bfloat16
81
+ device: cuda
82
+ checkpoint_dir: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/sft
83
+ save_every_steps: 10000
84
+ save_best_model: false
85
+ keep_last_n_checkpoints: 3
86
+ eval_every_steps: 10000
87
+ eval_samples: 32
88
+ log_every_steps: 20
89
+ aim_repo: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/.aim-sft
90
+ seed: 44
91
+ num_workers: 0
92
+ pin_memory: true
artifacts/diagnostics/activation_probe_pretrain_final.json ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt",
3
+ "loss": 2.8459982872009277,
4
+ "batch_size": 2,
5
+ "seq_len": 512,
6
+ "device": "cuda",
7
+ "dtype": "torch.bfloat16",
8
+ "layers": {
9
+ "blocks.0": {
10
+ "numel": 1048576,
11
+ "finite": 1048576,
12
+ "rms": 2.9217934608459473,
13
+ "max_abs": 62.90830612182617
14
+ },
15
+ "blocks.1": {
16
+ "numel": 1048576,
17
+ "finite": 1048576,
18
+ "rms": 3.4865853786468506,
19
+ "max_abs": 63.050193786621094
20
+ },
21
+ "blocks.2": {
22
+ "numel": 1048576,
23
+ "finite": 1048576,
24
+ "rms": 4.092358589172363,
25
+ "max_abs": 98.85921478271484
26
+ },
27
+ "blocks.3": {
28
+ "numel": 1048576,
29
+ "finite": 1048576,
30
+ "rms": 4.814863681793213,
31
+ "max_abs": 160.82785034179688
32
+ },
33
+ "blocks.4": {
34
+ "numel": 1048576,
35
+ "finite": 1048576,
36
+ "rms": 6.032325267791748,
37
+ "max_abs": 237.8623504638672
38
+ },
39
+ "blocks.5": {
40
+ "numel": 1048576,
41
+ "finite": 1048576,
42
+ "rms": 7.5034403800964355,
43
+ "max_abs": 324.9512939453125
44
+ },
45
+ "blocks.6": {
46
+ "numel": 1048576,
47
+ "finite": 1048576,
48
+ "rms": 8.303068161010742,
49
+ "max_abs": 345.4449462890625
50
+ },
51
+ "blocks.7": {
52
+ "numel": 1048576,
53
+ "finite": 1048576,
54
+ "rms": 8.911171913146973,
55
+ "max_abs": 337.6418151855469
56
+ },
57
+ "blocks.8": {
58
+ "numel": 1048576,
59
+ "finite": 1048576,
60
+ "rms": 10.293173789978027,
61
+ "max_abs": 343.7690124511719
62
+ },
63
+ "blocks.9": {
64
+ "numel": 1048576,
65
+ "finite": 1048576,
66
+ "rms": 11.591143608093262,
67
+ "max_abs": 361.6515808105469
68
+ },
69
+ "blocks.10": {
70
+ "numel": 1048576,
71
+ "finite": 1048576,
72
+ "rms": 13.886137962341309,
73
+ "max_abs": 401.029052734375
74
+ },
75
+ "blocks.11": {
76
+ "numel": 1048576,
77
+ "finite": 1048576,
78
+ "rms": 15.372880935668945,
79
+ "max_abs": 406.81036376953125
80
+ },
81
+ "blocks.12": {
82
+ "numel": 1048576,
83
+ "finite": 1048576,
84
+ "rms": 18.466873168945312,
85
+ "max_abs": 541.068603515625
86
+ },
87
+ "blocks.13": {
88
+ "numel": 1048576,
89
+ "finite": 1048576,
90
+ "rms": 20.376676559448242,
91
+ "max_abs": 553.8697509765625
92
+ },
93
+ "blocks.14": {
94
+ "numel": 1048576,
95
+ "finite": 1048576,
96
+ "rms": 23.877737045288086,
97
+ "max_abs": 652.4536743164062
98
+ },
99
+ "blocks.15": {
100
+ "numel": 1048576,
101
+ "finite": 1048576,
102
+ "rms": 28.56619644165039,
103
+ "max_abs": 874.2588500976562
104
+ },
105
+ "blocks.16": {
106
+ "numel": 1048576,
107
+ "finite": 1048576,
108
+ "rms": 35.949851989746094,
109
+ "max_abs": 1143.6116943359375
110
+ },
111
+ "blocks.17": {
112
+ "numel": 1048576,
113
+ "finite": 1048576,
114
+ "rms": 45.9725227355957,
115
+ "max_abs": 2560.034423828125
116
+ }
117
+ }
118
+ }
artifacts/diagnostics/generation_samples_pretrain_final.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt",
3
+ "tokenizer_path": "/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model",
4
+ "device": "cuda",
5
+ "dtype": "torch.bfloat16",
6
+ "max_new_tokens": 160,
7
+ "temperature": 0.8,
8
+ "top_p": 0.9,
9
+ "samples": [
10
+ {
11
+ "prompt": "The purpose of artificial intelligence is",
12
+ "completion": "to enable human creativity to be a kind of artifact of politics - and sometimes to do that the principle of human dignity cannot be denied without the distinction of two; but the only thing that gives those who rely on it the ability to be free to decide the right thing.In some cases, the majority of the rules about how to use such knowledge are not clearly understood by the law. This leads to the generation of the same arguments, which leads to the mistake of the expert who sought to apply it to everything else.Well, to comply with the human dignity. There is some good news. There are some types of discrimination. This is because it is always a death sentence. Many people are not even able to get to court."
13
+ },
14
+ {
15
+ "prompt": "In a small village,",
16
+ "completion": "you can live alone, but sometimes you might have to leave the house without a roof. Most of these city residents get two hours of sleep on most of the time, while many do.When it comes to school:Times help.Most of the time, wind is always the most important (or the least important): being well rested during the day.When the weather gets hot, the symptoms of insomnia can appear. Most of the time it just doesn't. Most people get enough sleep.Typically, we don't need to get more than 2 hours.Why are there such a large number of people getting the same kind of sleep?There are some factors that make it difficult to identify as having sleep problems. Some people do get a few"
17
+ },
18
+ {
19
+ "prompt": "<user>Hello, who are you?<assistant>",
20
+ "completion": "24.2 (2) 1:10Though many households still live without electricity, most people who live close to the power lines will do so to the extent that they have only a small amount of electricity at the end of the day (33%).(3) The average household electricity rate is higher for the two types of households at 32%, with 31% of the 2005-06 peaks.7Take home message:- Unrealistically high societies. This is one factor that makes getting it to such a huge difference. This means that one has to make some very poor choices when it comes to one's overall lifestyle. Many cities do not have a proper gas taxes;"
21
+ }
22
+ ]
23
+ }
artifacts/diagnostics/generation_samples_sft_final.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/sft/final_model.pt",
3
+ "tokenizer_path": "/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model",
4
+ "device": "cuda",
5
+ "dtype": "torch.bfloat16",
6
+ "max_new_tokens": 220,
7
+ "temperature": 0.7,
8
+ "top_p": 0.9,
9
+ "samples": [
10
+ {
11
+ "prompt": "<user>Hello, who are you?<assistant>",
12
+ "completion": "A tip is usually a net, a pack is a scavenger, which means that the initial effort is the delta, the body, the nav, is the ad. ⁇ ** Also, typically, a person typically becomes anemic and always begins to move around, which typically involves movement, which typically involves some movement, such as movement, which can also affect movement. ⁇ **Astrophobic** is not always a result of a very dramatic event. ⁇ * Often, it).That must be.*That got lost. ⁇ That got you got.)That must have got you confused. Most doctors agree. Muscles seem to lose flexibility.But steroids can still keep you sharp. Muscle cramps. Muscle cramps. Muscle fatigue. Most adults can still breathe. ⁇ **Most organs have multiple functions.** ⁇ Most treatments are aimed at improving symptoms. But advanced treatments may trigger inflammation. Low maintenance but soft but solid retain or soft core. Presence: Completing the major: Co the specific major"
13
+ },
14
+ {
15
+ "prompt": "<user>Explain what artificial intelligence is in simple words.<assistant>",
16
+ "completion": "A oscillator is a singular and eventually identical, specialized singular, the silencer. However, sometimes, sometimes, sometimes, the particles are almost always and typically surrounded by unusual or unusual characteristics, such as being very agile, such as being very small, very small, and often very dense, sometimes very really really. ⁇ They can also sometimes also sometimes sometimes sometimes. Sometimes, sometimes sometimes sometimes sometimes. Sometimes, sometimes sometimes sometimes sometimes. ⁇ However,).).....)....That.).G...That.K..K..New..New....New..New......These two devices seem quite close. Each has its own advantage. ⁇ But sometimes. But sometimes. It doesn't. But it still does. But... The hypothesis. The hypothesis. The untested or rare. Comes in red. P). Complex. Co. High. CoinIn the general, the"
17
+ },
18
+ {
19
+ "prompt": "<user>Give me three practical study tips.<assistant>",
20
+ "completion": "Dementia, also known as numbness or tingling under the arc (the gait), is a physical deformity. The necessity is the 11111111 and also typically includes an injury or condition that makes it difficult for you to do some simple things, such as getting dressed, getting some rest, or getting some sleep, but not actually getting really tired. Sometimes simply having this sometimes sometimes sometimes does really does sometimes does. Sometimes, sometimes sometimes sometimes. Sometimes sometimes.)It.The....).)...).That...That.. Two. ⁇ .That.. ⁇ Check it. ⁇ Check everything. ⁇ Check it. ⁇ But keep it small. Remember to keep everything hidden. Remember to keep everything organized. Remember to keep everyone informed. Remember to keep valuable items readily accessible. Most businesses don't offer absolute protection. Let the deco stay solid. Cos. Cos. Co. is a specific sub. ⁇ In a Coin,"
21
+ }
22
+ ]
23
+ }
artifacts/gpu_telemetry_nvidia_smi.csv ADDED
The diff for this file is too large to render. See raw diff
 
artifacts/loss_curve_summary.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pretrain": {
3
+ "points": 961817,
4
+ "first_loss": 9.26,
5
+ "last_loss": 2.64,
6
+ "min_loss": 0.577,
7
+ "max_loss": 9.28,
8
+ "tail20_mean": 2.367,
9
+ "tail100_mean": 2.3350999999999997,
10
+ "tail200_mean": 2.3415500000000002,
11
+ "tail1000_mean": 2.373460000000001,
12
+ "csv_path": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/analysis/pretrain_loss_curve.csv"
13
+ },
14
+ "sft": {
15
+ "points": 49238,
16
+ "first_loss": 3.2,
17
+ "last_loss": 1.08,
18
+ "min_loss": 0.386,
19
+ "max_loss": 3.38,
20
+ "tail20_mean": 0.9390999999999998,
21
+ "tail100_mean": 0.9585200000000004,
22
+ "tail200_mean": 0.965355,
23
+ "tail1000_mean": 0.9520289999999981,
24
+ "csv_path": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/analysis/sft_loss_curve.csv"
25
+ }
26
+ }
artifacts/model_card.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "state": "completed",
3
+ "architecture": "taonet_ssm",
4
+ "candidate": "pure_ssm_196m_branch_rms_only",
5
+ "pretrain_checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt",
6
+ "sft_checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/sft/final_model.pt",
7
+ "tokenizer_path": "/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model",
8
+ "pretrain_config": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/configs/pretrain.yaml",
9
+ "sft_config": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/configs/sft.yaml",
10
+ "activation_probe": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/diagnostics/activation_probe_pretrain_final.json",
11
+ "pretrain_samples": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/diagnostics/generation_samples_pretrain_final.json",
12
+ "sft_samples": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/diagnostics/generation_samples_sft_final.json"
13
+ }
artifacts/run.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -Eeuo pipefail
3
+
4
+ job_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ status_json="$job_dir/status.json"
6
+ log_path="$job_dir/train.log"
7
+ telemetry_path="$job_dir/gpu_telemetry_nvidia_smi.csv"
8
+
9
+ write_status() {
10
+ local state="$1"
11
+ local extra="${2:-}"
12
+ local now
13
+ now="$(date -Iseconds)"
14
+ cat > "$status_json" <<STATUS
15
+ {"state":"$state","updated_at":"$now","job_dir":"$job_dir"$extra}
16
+ STATUS
17
+ }
18
+
19
+ write_status "running" ",\"start_time\":\"$(date -Iseconds)\""
20
+ rm -f "$job_dir/DONE" "$job_dir/FAILED"
21
+
22
+ monitor_pid=""
23
+ if command -v nvidia-smi >/dev/null 2>&1; then
24
+ (
25
+ while true; do
26
+ date -Iseconds
27
+ nvidia-smi --query-gpu=timestamp,name,utilization.gpu,utilization.memory,memory.used,memory.total,power.draw,temperature.gpu --format=csv,noheader,nounits
28
+ sleep 5
29
+ done
30
+ ) > "$telemetry_path" 2>&1 &
31
+ monitor_pid="$!"
32
+ fi
33
+
34
+ cleanup() {
35
+ if [[ -n "$monitor_pid" ]]; then
36
+ kill "$monitor_pid" 2>/dev/null || true
37
+ wait "$monitor_pid" 2>/dev/null || true
38
+ fi
39
+ }
40
+ trap cleanup EXIT
41
+
42
+ set +e
43
+ "$job_dir/command.sh" > "$log_path" 2>&1
44
+ exit_code="$?"
45
+ set -e
46
+
47
+ if [[ "$exit_code" -eq 0 ]]; then
48
+ touch "$job_dir/DONE"
49
+ write_status "completed" ",\"exit_code\":0,\"end_time\":\"$(date -Iseconds)\""
50
+ else
51
+ echo "$exit_code" > "$job_dir/FAILED"
52
+ write_status "failed" ",\"exit_code\":$exit_code,\"end_time\":\"$(date -Iseconds)\""
53
+ fi
54
+
55
+ exit "$exit_code"
artifacts/run_plan.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "purpose": "200m_branch_only_pure_ssm_4b_pretrain_plus_sft_chatbot_attempt",
3
+ "candidate": "pure_ssm_196m_branch_rms_only",
4
+ "pretrain_token_positions": 4000000000,
5
+ "pretrain_steps": 976563,
6
+ "batch_size": 8,
7
+ "seq_len": 512,
8
+ "sft_steps": 50000,
9
+ "pretrain_lr": 0.0008,
10
+ "sft_lr": 0.00005,
11
+ "save_every_steps": 100000,
12
+ "block_residual_rms_cap": "null",
13
+ "notes": [
14
+ "Uses streaming JSONL training through TaoTrain CLI.",
15
+ "Counts token positions as batch_size * seq_len * max_steps.",
16
+ "Keeps SSM branch RMS normalization enabled and block residual RMS normalization disabled.",
17
+ "Runs corrected response-only SFT after pretraining."
18
+ ]
19
+ }
artifacts/status.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"state":"completed","updated_at":"2026-05-15T19:22:45+08:00","job_dir":"/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514","exit_code":0,"end_time":"2026-05-15T19:22:45+08:00"}
code/TaoTrain/.gitignore ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ *.py,cover
49
+ .hypothesis/
50
+ .pytest_cache/
51
+
52
+ # Translations
53
+ *.mo
54
+ *.pot
55
+
56
+ # Django stuff:
57
+ *.log
58
+ local_settings.py
59
+ db.sqlite3
60
+ db.sqlite3-journal
61
+
62
+ # Flask stuff:
63
+ instance/
64
+ .webassets-cache
65
+
66
+ # Scrapy stuff:
67
+ .scrapy
68
+
69
+ # Sphinx documentation
70
+ docs/_build/
71
+
72
+ # PyBuilder
73
+ target/
74
+
75
+ # Jupyter Notebook
76
+ .ipynb_checkpoints
77
+
78
+ # IPython
79
+ profile_default/
80
+ ipython_config.py
81
+
82
+ # pyenv
83
+ .python-version
84
+
85
+ # pipenv
86
+ Pipfile.lock
87
+
88
+ # PEP 582
89
+ __pypackages__/
90
+
91
+ # Celery stuff
92
+ celerybeat-schedule
93
+ celerybeat.pid
94
+
95
+ # SageMath parsed files
96
+ *.sage.py
97
+
98
+ # Environments
99
+ .env
100
+ .venv
101
+ env/
102
+ venv/
103
+ ENV/
104
+ env.bak/
105
+ venv.bak/
106
+
107
+ # Spyder project settings
108
+ .spyderproject
109
+ .spyproject
110
+
111
+ # Rope project settings
112
+ .ropeproject
113
+
114
+ # mkdocs documentation
115
+ /site
116
+
117
+ # mypy
118
+ .mypy_cache/
119
+ .dmypy.json
120
+ dmypy.json
121
+
122
+ # Pyre type checker
123
+ .pyre/
124
+
125
+ # IDEs
126
+ .vscode/
127
+ .idea/
128
+ *.swp
129
+ *.swo
130
+ *~
131
+
132
+ # OS
133
+ .DS_Store
134
+ Thumbs.db
135
+
136
+ # Project-specific
137
+ checkpoints/
138
+ runs/
139
+ results/
140
+ *.pt
141
+ *.pth
142
+ *.safetensors
143
+ aim_logs/
144
+ .aim/
145
+ wandb/
146
+ lightning_logs/
147
+
148
+ # Dataset cache
149
+ .cache/
150
+ datasets_cache/
code/TaoTrain/README.md ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TaoTrain: Production-Grade LLM Training Framework
2
+
3
+ **TaoTrain** is a sophisticated PyTorch framework for training large language models at every scale—from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the **entire training pipeline** in a clean, modular codebase that appeals to both ML engineers and software engineers.
4
+
5
+ ## Current Taotern Work
6
+
7
+ TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work:
8
+
9
+ - `taonet`: the attention/MLA baseline.
10
+ - `taonet_ssm`: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM.
11
+ - `taonet_hybrid`: an alternating attention/SSM TaoNet used for the current best 200M-class candidate.
12
+
13
+ The current selected deployment-oriented run is `hybrid_ssm_first_199m`, a `199,480,928` parameter model with 16 layers: SSM layers at `0,2,4,6,8,10,12,14` and attention layers at `1,3,5,7,9,11,13,15`. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training.
14
+
15
+ Remote run `taotern-200m-hybrid-chat-20260512` trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The trainable fixes added for this run are:
16
+
17
+ - Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue.
18
+ - Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active `samples_per_chunk` and `max_samples` settings.
19
+
20
+ ## Why TaoTrain?
21
+
22
+ - **Complete Unified Pipeline**: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures.
23
+ - **Production-Grade Engineering**: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling—not research code, but a framework you can deploy.
24
+ - **Extensibility Without Modification**: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework.
25
+ - **Developer Experience First**: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with.
26
+
27
+ ## Key Capabilities
28
+
29
+ | Capability | Details |
30
+ |---|---|
31
+ | **Multi-Stage Training** | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. |
32
+ | **Advanced Optimization** | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. |
33
+ | **Modern Architectures** | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings—all configurable via YAML. |
34
+ | **Production Features** | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. |
35
+ | **Async Data Pipeline** | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. |
36
+ | **Interactive Inference** | TUI chat interface with real-time generation speed metrics and multi-model comparison. |
37
+ | **Logging & Monitoring** | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. |
38
+
39
+ ## Getting Started
40
+
41
+ ### Installation
42
+
43
+ ```bash
44
+ git clone https://github.com/lobakkang/taoTrain.git
45
+ cd taoTrain
46
+ pip install -e .
47
+ ```
48
+
49
+ ### Training Examples
50
+
51
+ **Pretraining on a custom dataset:**
52
+ ```bash
53
+ train pretrain --config configs/pretrain.yaml
54
+ ```
55
+ Starts from scratch, learns representations from raw text via next-token prediction.
56
+
57
+ **Supervised Fine-tuning:**
58
+ ```bash
59
+ train sft --config configs/sft.yaml
60
+ ```
61
+ Fine-tune a pretrained model on instruction-response pairs for improved task performance.
62
+
63
+ **Reinforcement Learning (DPO):**
64
+ ```bash
65
+ train rl --config configs/rl_dpo.yaml
66
+ ```
67
+ Align models with human preferences using Direct Preference Optimization.
68
+
69
+ **Interactive Chat:**
70
+ ```bash
71
+ tui-chat --model checkpoints/model.pt
72
+ ```
73
+ Launch an interactive TUI to chat with your model and monitor generation metrics in real-time.
74
+
75
+ ### Configuration
76
+
77
+ All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated:
78
+
79
+ ```yaml
80
+ # configs/sft.yaml
81
+ model:
82
+ architecture_type: "mla" # DeepSeek MLA with GQA
83
+ hidden_dim: 2048
84
+ num_layers: 24
85
+ num_heads: 32
86
+ d_latent_kv: 1536 # KV compression factor
87
+
88
+ training:
89
+ num_epochs: 3
90
+ batch_size: 32
91
+ learning_rate: 1e-4
92
+ warmup_ratio: 0.1
93
+ max_grad_norm: 1.0
94
+
95
+ optimizer:
96
+ optimizer_type: "muon_adamw" # Hybrid Muon + AdamW
97
+ muon_momentum: 0.95
98
+
99
+ data:
100
+ dataset_type: "sft_jsonl" # or "sft_hf" for HuggingFace
101
+ path: "data/sft_training.jsonl"
102
+
103
+ logging:
104
+ log_to_aim: true
105
+ aim_repo: "/tmp/aim_logs"
106
+ ```
107
+
108
+ See `configs/` for complete examples.
109
+
110
+ ## Project Architecture
111
+
112
+ ```
113
+ src/taoTrain/
114
+ ├── cli.py # Main CLI entry point
115
+ ├── config.py # Pydantic configuration schemas
116
+
117
+ ├── core/ # Base abstractions
118
+ │ └── base.py # BaseModel, BaseDataset, BaseTrainer
119
+
120
+ ├── models/ # Pluggable architecture system
121
+ │ ├── registry.py # Architecture factory with @register_architecture
122
+ │ ├── taonet.py # SimpleLLM with DeepSeek MLA
123
+ │ ├── mla_components.py # KV compression, GQA, YaRN
124
+ │ ├── embeddings.py # Factorized embeddings
125
+ │ └── transformer.py # Standard Transformer reference
126
+
127
+ ├── data/ # Advanced data pipeline
128
+ │ ├── factory.py # Dataset factory (HF + JSONL backends)
129
+ │ ├── async_loader.py # Async batch iteration (no I/O bottleneck)
130
+ │ ├── tokenization_queue.py # Background multi-threaded tokenization
131
+ │ ├── chunk_manager.py # Stream billion-token JSONL files
132
+ │ ├── hf_pretrain.py # HuggingFace pretraining datasets
133
+ │ ├── hf_sft.py # HuggingFace SFT datasets
134
+ │ ├── hf_rl.py # HuggingFace RL datasets
135
+ │ ├── pretrain_jsonl.py # JSONL pretraining
136
+ │ ├── sft_jsonl.py # JSONL SFT with instructions
137
+ │ └── rl_jsonl.py # JSONL RL with preferences
138
+
139
+ ├── training/ # Unified training infrastructure
140
+ │ └── trainer.py # Trainer + PretrainTrainer, SFTTrainer, RLTrainer
141
+
142
+ ├── optimizers/ # Pluggable optimizer system
143
+ │ ├── registry.py # Optimizer factory with @register_optimizer
144
+ │ ├── hybrid_muon_adamw.py # Composite: Muon (2D) + AdamW (1D)
145
+ │ ├── adamw.py # AdamW with weight decay
146
+ │ ├── adam.py # Standard Adam
147
+ │ └── sgd.py # SGD variants
148
+
149
+ ├── schedulers/ # Learning rate schedules
150
+ │ ├── registry.py # LR scheduler factory
151
+ │ ├── cosine_warmup.py # 3-phase: linear warmup → plateau → cosine decay
152
+ │ ├── linear_warmup.py # Linear warmup + constant
153
+ │ └── constant.py # Constant learning rate
154
+
155
+ ├── inference/ # Inference & interaction
156
+ │ ├── inferencer.py # Load & run inference from checkpoints
157
+ │ └── tui.py # Interactive chat with metrics display
158
+
159
+ ├── checkpointing/ # State management
160
+ │ └── checkpoint.py # Save/load model + optimizer + config + metrics
161
+
162
+ ├── logging/ # Experiment tracking
163
+ │ └── aim_logger.py # AimStack integration (loss, metrics, hyperparams)
164
+
165
+ ├── benchmarks/ # Evaluation tools
166
+ │ └── runner.py # Perplexity, speed, and task-specific benchmarks
167
+
168
+ └── utils/
169
+ └── helpers.py # Utility functions
170
+
171
+ configs/ # Example YAML configurations
172
+ ├── pretrain.yaml # Pretraining config
173
+ ├── sft.yaml # SFT config
174
+ ├── rl_dpo.yaml # RL/DPO config
175
+ └── tokenizer.yaml # Tokenizer config
176
+
177
+ tests/ # Unit & integration tests
178
+ └── test_dataset.py
179
+ ```
180
+
181
+ ## Extensible Architecture: The Registry Pattern
182
+
183
+ TaoTrain's power lies in its **pluggable design**. Add custom models, optimizers, schedulers, and datasets without modifying the framework.
184
+
185
+ ### Custom Model Architecture
186
+
187
+ ```python
188
+ from taoTrain.models import register_architecture, BaseModel
189
+ import torch.nn as nn
190
+
191
+ @register_architecture("custom_moe")
192
+ class MixtureOfExperts(BaseModel):
193
+ """Your custom MoE architecture"""
194
+ def __init__(self, config):
195
+ super().__init__(config)
196
+ self.experts = nn.ModuleList([
197
+ nn.Linear(config.hidden_dim, config.hidden_dim)
198
+ for _ in range(config.num_experts)
199
+ ])
200
+ self.router = nn.Linear(config.hidden_dim, config.num_experts)
201
+
202
+ def forward(self, input_ids, attention_mask=None):
203
+ # Your implementation
204
+ logits = self.compute_logits(input_ids)
205
+ loss = self.compute_loss(logits, labels) if labels is not None else None
206
+ return {"logits": logits, "loss": loss}
207
+ ```
208
+
209
+ Then use it in your config:
210
+
211
+ ```yaml
212
+ model:
213
+ architecture_type: "custom_moe"
214
+ hidden_dim: 2048
215
+ num_experts: 8
216
+ ```
217
+
218
+ ### Custom Optimizers & Schedulers
219
+
220
+ The same pattern works for optimizers and learning rate schedules:
221
+
222
+ ```python
223
+ from taoTrain.optimizers import register_optimizer
224
+ from torch.optim import Optimizer
225
+
226
+ @register_optimizer("my_adaptive_optimizer")
227
+ class MyAdaptiveOptimizer(Optimizer):
228
+ def step(self, closure=None):
229
+ # Your optimization logic
230
+ pass
231
+ ```
232
+
233
+ ```python
234
+ from taoTrain.schedulers import register_scheduler
235
+
236
+ @register_scheduler("my_schedule")
237
+ def my_schedule(initial_lr, step, total_steps, **kwargs):
238
+ return initial_lr * (1.0 - step / total_steps) # Linear decay
239
+ ```
240
+
241
+ **The key principle**: No framework code needs to change. You register once, it's available everywhere.
242
+
243
+ ### Dataset Backend Flexibility
244
+
245
+ Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them:
246
+
247
+ ```python
248
+ from taoTrain.data import register_dataset
249
+
250
+ @register_dataset("pretrain", "my_backend")
251
+ class MyPretrainDataset(BaseDataset):
252
+ def __init__(self, config):
253
+ # Load from your custom backend
254
+ pass
255
+
256
+ def __getitem__(self, idx):
257
+ return {"input_ids": ..., "attention_mask": ...}
258
+ ```
259
+
260
+ Use in config:
261
+
262
+ ```yaml
263
+ data:
264
+ dataset_type: "pretrain"
265
+ backend_type: "my_backend" # Routes to MyPretrainDataset
266
+ ```
267
+
268
+ ## Why TaoTrain Framework?
269
+
270
+ ### Async Data Loading: No I/O Bottleneck
271
+
272
+ Most training frameworks load and tokenize data on the main training thread, blocking compute. TaoTrain's **multi-threaded tokenization pipeline**:
273
+
274
+ - Tokenizes data in background workers while your GPU trains
275
+ - Supports streaming billion-token JSONL files without loading into memory
276
+ - Intelligent chunking (by file size or sample count)
277
+ - Metadata caching to avoid rescanning
278
+
279
+ **Result**: 10-100x faster data iteration on large datasets.
280
+
281
+ ### Type-Safe Configuration
282
+
283
+ Forget YAML parsing errors or mysterious config bugs. TaoTrain uses **Pydantic dataclasses** for configuration:
284
+
285
+ - Automatic type validation: mistyped `learning_rate: "1e-4"` becomes an error, not silent failure
286
+ - Serialization: configs are part of checkpoints, ensuring reproducibility
287
+ - IDE support: autocomplete and type hints for all config fields
288
+ - Defaults: sensible defaults for all parameters
289
+
290
+ ### Benchmarking & Metrics
291
+
292
+ Track what matters:
293
+
294
+ - **Perplexity**: Language modeling quality on held-out data
295
+ - **Generation Speed**: Tokens-per-second (useful for TUI or deployment)
296
+ - **Task-Specific Accuracy**: Evaluate on downstream tasks
297
+ - **Training Metrics**: Loss curves, gradient norms, effective batch size
298
+
299
+ All logged to AimStack with git hashes for reproducibility.
300
+
301
+ ## Logging with AimStack
302
+
303
+ Automatically track and visualize experiments:
304
+
305
+ ```bash
306
+ aim up --host 0.0.0.0
307
+ ```
308
+
309
+ Then open `http://localhost:43800` to see:
310
+
311
+ - **Loss curves** per training step
312
+ - **Hyperparameters** (learning rate, batch size, model architecture)
313
+ - **Git hashes** for reproducibility
314
+ - **Custom metrics** (perplexity, validation accuracy, generation speed)
315
+ - **Compare runs**: Side-by-side experiment comparison
316
+
317
+ ## Advanced Features
318
+
319
+ ### Checkpointing with Resumption
320
+
321
+ TaoTrain saves complete training state:
322
+
323
+ ```python
324
+ checkpoint = {
325
+ "step": 12500,
326
+ "model_state": model.state_dict(),
327
+ "optimizer_state": optimizer.state_dict(),
328
+ "config": config, # Full config as Pydantic object
329
+ "metrics": metrics_tracker.to_dict(),
330
+ }
331
+ ```
332
+
333
+ Resume training from any checkpoint without loss of state. Keep last N checkpoints automatically.
334
+
335
+ ### Mixed Precision Training (BF16)
336
+
337
+ ```yaml
338
+ training:
339
+ use_bfloat16: true
340
+ gradient_accumulation_steps: 4
341
+ ```
342
+
343
+ - BF16 via `torch.autocast` for ~2x speedup with minimal accuracy loss
344
+ - Proper gradient scaling and clipping
345
+ - Compatible with all optimizers and architectures
346
+
347
+ ### 3-Phase Learning Rate Schedule
348
+
349
+ ```yaml
350
+ scheduler:
351
+ scheduler_type: "cosine_warmup"
352
+ warmup_ratio: 0.1 # 10% of training steps
353
+ steady_ratio: 0.5 # 50% at steady rate
354
+ min_lr_ratio: 0.1 # Final LR = 0.1 × initial_lr
355
+ num_cycles: 1
356
+ ```
357
+
358
+ This schedule:
359
+ 1. **Linear warmup** (0 → 1) over 10% of steps
360
+ 2. **Steady plateau** at full LR over 50% of steps
361
+ 3. **Cosine decay** (1 → 0.1) over remaining 40% of steps
362
+
363
+ Better convergence than simple cosine or linear decay.
364
+
365
+ ### Gradient Accumulation & Clipping
366
+
367
+ Simulate larger batch sizes with gradient accumulation:
368
+
369
+ ```yaml
370
+ training:
371
+ batch_size: 32
372
+ gradient_accumulation_steps: 4 # Effective batch = 128
373
+ max_grad_norm: 1.0 # Gradient clipping
374
+ ```
375
+
376
+ ## Contributing
377
+
378
+ Contributions are welcome! TaoTrain is designed to make contributions easy:
379
+
380
+ 1. **Add a model**: Implement `BaseModel` and `@register_architecture("name")`
381
+ 2. **Add an optimizer**: Implement `torch.optim.Optimizer` and `@register_optimizer("name")`
382
+ 3. **Add a dataset**: Implement `BaseDataset` and `@register_dataset(mode, backend_type)`
383
+ 4. **Improve the core**: Submit PRs to `training/`, `data/`, `logging/`, etc.
384
+
385
+ Ensure new code includes:
386
+ - Type hints throughout
387
+ - Pydantic configs for new parameters
388
+ - Unit tests in `tests/`
389
+ - Documentation in docstrings and README
390
+
391
+ ## Current Scope & Roadmap
392
+
393
+ ### ✅ Currently Supported
394
+
395
+ - **Single GPU / single node** training
396
+ - **Pretraining, SFT, and RL training** stages
397
+ - **HuggingFace and JSONL** data backends
398
+ - **BF16 mixed precision** training
399
+ - **Checkpoint saving/loading** with resumption
400
+ - **Interactive inference** via TUI
401
+ - **Benchmarking** (perplexity, speed)
402
+ - **Pluggable architectures, optimizers, schedulers, datasets**
403
+
404
+ ### 🚀 Roadmap (Future)
405
+
406
+ - **Distributed training** (DDP, FSDP) for multi-GPU/multi-node scaling
407
+ - **Quantization** support (INT8, QLoRA)
408
+ - **Advanced evaluation** (BLEU, ROUGE, custom tasks)
409
+ - **Streaming inference** with KV cache
410
+ - **Speculative decoding** for faster generation
411
+ - **Integration with popular model hubs** (Hugging Face Hub upload/download)
412
+
413
+ ---
414
+
415
+ ## Getting Help
416
+
417
+ - **Questions?** Open an issue on GitHub
418
+ - **Want to contribute?** See `CONTRIBUTING.md` (coming soon)
419
+ - **Found a bug?** Report it with a minimal reproduction script
420
+
421
+ ## License
422
+
423
+ MIT
code/TaoTrain/configs/pretrain.yaml ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TaoNet T2 Configuration for Pretraining
2
+ # DeepSeek MLA + RoPE with Hybrid Muon+AdamW Optimizer
3
+ # Full BF16 precision (no quantization)
4
+
5
+ # ============================================================================
6
+ # Model Architecture - TaoNet (DeepSeek MLA + RoPE)
7
+ # ============================================================================
8
+ model:
9
+ architecture_type: taonet
10
+ vocab_size: 8192
11
+ hidden_dim: 768
12
+ num_layers: 12
13
+ num_heads: 8
14
+ max_seq_length: 1024
15
+
16
+ # TaoNet-specific: Multi-head Latent Attention (MLA) compression
17
+ d_latent_kv: 512
18
+
19
+ # RoPE (Rotary Position Embedding) dimension per head
20
+ # Default would be 512 / 8 = 64
21
+ d_rope: 64
22
+
23
+ # Feed-forward intermediate dimension
24
+ hidden_dim_ff: 2048
25
+
26
+ # Dropout rate (low for stability with large models)
27
+ dropout: 0.02
28
+
29
+ # Grouped Query Attention (1 = standard MLA, >1 = GQA)
30
+ gqa_groups: 1
31
+
32
+ # Optional: Use factorized embedding for parameter efficiency
33
+ # vocab (8192) → rank (96) → hidden (512)
34
+ use_factorized_embedding: false
35
+ d_embed_rank: 96
36
+
37
+ # Weight initialization standard deviation
38
+ init_std: 0.02
39
+
40
+ # ============================================================================
41
+ # Dataset Configuration - Local JSONL
42
+ # ============================================================================
43
+ dataset:
44
+ local: true
45
+ jsonl_path: /home/student/Data/TaoData/output.jsonl
46
+ text_field: text
47
+ max_samples: 1000000
48
+ samples_per_chunk: 1000
49
+
50
+ # Tokenizer configuration
51
+ tokenizer_type: sentencepiece
52
+ tokenizer_path: tokenizer/tokenizer.model
53
+ tokenizer_threads: 4
54
+
55
+ # ============================================================================
56
+ # Training Hyperparameters
57
+ # ============================================================================
58
+ batch_size: 32
59
+ num_epochs: 2 # Set to 10 for full training
60
+ gradient_accumulation_steps: 8 # Effective batch: 32 × 8 = 256
61
+
62
+ # Maximum gradient norm for clipping (prevents ternary instability)
63
+ max_grad_norm: 1.0
64
+
65
+ # ============================================================================
66
+ # Optimizer - Hybrid Muon + AdamW
67
+ # ============================================================================
68
+ # Strategy:
69
+ # - Muon: For 2D Linear weight matrices (orthogonal/SVD-based optimization)
70
+ # - 2D weights: learning_rate (5e-3)
71
+ # - AdamW: For 1D parameters (biases, norms, embeddings)
72
+ # - 1D params: adamw_lr (5e-4) = 1/10 × learning_rate
73
+
74
+ optimizer:
75
+ optimizer_type: hybrid_muon_adamw
76
+
77
+ # Learning rate for Muon (2D weight matrices)
78
+ learning_rate: 5e-3
79
+
80
+ # Learning rate for AdamW (1D parameters)
81
+ # Typically 1/10 of learning_rate to prevent over-updating 1D params
82
+ adamw_lr: 5e-4
83
+
84
+ # L2 regularization (weight decay)
85
+ weight_decay: 0.01
86
+
87
+ # Adam betas
88
+ betas: [0.9, 0.999]
89
+
90
+ # Epsilon for numerical stability
91
+ eps: 1e-8
92
+
93
+ # ============================================================================
94
+ # Learning Rate Scheduler - 3-Phase Cosine with Warmup
95
+ # ============================================================================
96
+ # Phases:
97
+ # 1. Warmup: 0 → 1.0 (300 steps, ~1.4% of training)
98
+ # 2. Steady: 1.0 (constant for 5% of training)
99
+ # 3. Decay: 1.0 → 0.1 (cosine decay for remaining 95%)
100
+
101
+ scheduler:
102
+ scheduler_type: cosineWarmup
103
+ warmup_steps: 300
104
+ warmup_ratio: 0.0 # Ignored if warmup_steps > 0
105
+ steady_ratio: 0.05 # 5% of total training steps at peak LR
106
+ min_lr_ratio: 0.1 # Decay to 10% of peak LR
107
+ num_cycles: 0.5 # For compatibility (not used in 3-phase schedule)
108
+
109
+ # ============================================================================
110
+ # Data Type and Device
111
+ # ============================================================================
112
+ dtype: bfloat16 # Use BF16 for better convergence with large models
113
+ device: cuda # Use GPU for training
114
+
115
+ # ============================================================================
116
+ # Checkpointing and Validation
117
+ # ============================================================================
118
+ checkpoint_dir: checkpoints/test
119
+ save_every_steps: 81920
120
+ save_best_model: true
121
+ keep_last_n_checkpoints: 3
122
+
123
+ # Validation
124
+ eval_every_steps: 8192
125
+ eval_samples: 8000
126
+
127
+ # ============================================================================
128
+ # Logging
129
+ # ============================================================================
130
+ log_every_steps: 50
131
+ aim_repo: .aim
132
+
133
+ # ============================================================================
134
+ # Miscellaneous
135
+ # ============================================================================
136
+ seed: 42
137
+ num_workers: 0
138
+ pin_memory: true
code/TaoTrain/configs/rl_dpo.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example configuration for RL training (RL stage assumes you have a reward model)
2
+
3
+ model:
4
+ architecture_type: transformer
5
+ vocab_size: 50257
6
+ hidden_dim: 256
7
+ num_layers: 8
8
+ num_heads: 8
9
+ dropout: 0.1
10
+ max_seq_length: 512
11
+ init_std: 0.02
12
+
13
+ dataset:
14
+ dataset_name: allenai/real_toxicity_prompts
15
+ split: train
16
+ prompt_column: text
17
+ max_samples: 2000
18
+ cache_dir: .cache/datasets
19
+ tokenizer_threads: 1 # Number of background threads for tokenization (1-32 recommended)
20
+
21
+ batch_size: 4
22
+ num_epochs: 1
23
+ gradient_accumulation_steps: 8
24
+ max_grad_norm: 0.5
25
+
26
+ optimizer:
27
+ optimizer_type: adamw
28
+ learning_rate: 1e-5
29
+ weight_decay: 0.0
30
+
31
+ scheduler:
32
+ scheduler_type: linearWarmup
33
+ warmup_steps: 50
34
+
35
+ dtype: bfloat16
36
+ device: cuda
37
+
38
+ checkpoint_dir: checkpoints/rl
39
+ save_every_steps: 100
40
+ save_best_model: false
41
+ keep_last_n_checkpoints: 2
42
+
43
+ eval_every_steps: 100
44
+ eval_samples: 100
45
+
46
+ log_every_steps: 10
47
+ aim_repo: .aim
48
+
49
+ # RL-specific settings
50
+ rl_method: ppo # or "dpo"
51
+ reward_model_path: checkpoints/reward_model.pt # Path to your reward model
52
+ ppo_epochs: 4
53
+ ppo_clip_ratio: 0.2
54
+ entropy_coeff: 0.01
55
+ value_loss_coeff: 1.0
56
+ generation_max_length: 256
57
+
58
+ seed: 42
59
+ num_workers: 0
60
+ pin_memory: true
code/TaoTrain/configs/sft.yaml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example configuration for supervised fine-tuning
2
+ # Uses TaoNet (MLA+RoPE) architecture loaded from pretrained checkpoint
3
+
4
+ # ============================================================================
5
+ # Model Architecture - TaoNet (DeepSeek MLA + RoPE)
6
+ # ============================================================================
7
+ model:
8
+ architecture_type: taonet
9
+ vocab_size: 8192
10
+ hidden_dim: 768
11
+ num_layers: 12
12
+ num_heads: 8
13
+ max_seq_length: 1024
14
+
15
+ # TaoNet-specific: Multi-head Latent Attention (MLA) compression
16
+ d_latent_kv: 512
17
+
18
+ # RoPE (Rotary Position Embedding) dimension per head
19
+ # Default would be 512 / 8 = 64
20
+ d_rope: 64
21
+
22
+ # Feed-forward intermediate dimension
23
+ hidden_dim_ff: 2048
24
+
25
+ # Dropout rate (low for stability with large models)
26
+ dropout: 0.02
27
+
28
+ # Grouped Query Attention (1 = standard MLA, >1 = GQA)
29
+ gqa_groups: 1
30
+
31
+ # Optional: Use factorized embedding for parameter efficiency
32
+ # vocab (8192) → rank (96) → hidden (512)
33
+ use_factorized_embedding: false
34
+ d_embed_rank: 96
35
+
36
+ # Weight initialization standard deviation
37
+ init_std: 0.02
38
+
39
+ dataset:
40
+ split: train
41
+ instruction_column: input
42
+ response_column: output
43
+
44
+ local: true
45
+ jsonl_path: /home/student/Data/TaoData/sft.jsonl
46
+ samples_per_chunk: 1000
47
+ #max_samples: 2000000
48
+ max_samples: 1000000
49
+ cache_dir: .cache/datasets
50
+ instruction_template: "{instruction}\n{response}"
51
+
52
+ # Tokenizer configuration
53
+ tokenizer_type: sentencepiece
54
+ tokenizer_path: tokenizer/tokenizer.model
55
+ tokenizer_threads: 4
56
+
57
+ # SFT-specific configuration (these fields are in SFTConfig)
58
+ checkpoint_path: "checkpoints/pretrain/final_model.pt"
59
+ user_token: "<user>"
60
+ assistant_token: "<assistant>"
61
+ response_loss_only: true
62
+
63
+ batch_size: 8
64
+ num_epochs: 1
65
+ gradient_accumulation_steps: 4
66
+ max_grad_norm: 1.0
67
+
68
+ optimizer:
69
+ optimizer_type: adamw
70
+ learning_rate: 5e-5 # Lower LR for fine-tuning (vs 5e-4 pretrain base, 5e-3 Muon)
71
+ weight_decay: 0.01
72
+
73
+ scheduler:
74
+ scheduler_type: linearWarmup
75
+ warmup_steps: 500 # Less aggressive warmup for fine-tuning
76
+
77
+ dtype: bfloat16
78
+ device: cuda
79
+
80
+ checkpoint_dir: checkpoints/sft
81
+ save_every_steps: 81920
82
+ save_best_model: true
83
+ keep_last_n_checkpoints: 2
84
+
85
+ eval_every_steps: 8192
86
+ eval_samples: 200
87
+
88
+ log_every_steps: 10
89
+ aim_repo: .aim
90
+
91
+ seed: 42
92
+ num_workers: 0
93
+ pin_memory: true
code/TaoTrain/configs/ssm_pretrain.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TaoNet-SSM T2 Configuration for Pretraining
2
+ # TaoNet shell with Gamma SSM replacing the MLA attention mixer.
3
+
4
+ model:
5
+ architecture_type: taonet_ssm
6
+ vocab_size: 8192
7
+ hidden_dim: 768
8
+ num_layers: 12
9
+ num_heads: 8
10
+ max_seq_length: 1024
11
+
12
+ d_latent_kv: 512
13
+ d_rope: 64
14
+ hidden_dim_ff: 2048
15
+ dropout: 0.02
16
+ gqa_groups: 1
17
+ use_factorized_embedding: false
18
+ d_embed_rank: 96
19
+ init_std: 0.02
20
+
21
+ ssm_core: dplr
22
+ ssm_hidden_dim: 512
23
+ ssm_mixer_dim: 256
24
+ ssm_rank: 1
25
+ ssm_max_low_rank_scale: 0.1
26
+ ssm_discretization: bilinear
27
+ ssm_kernel_mode: auto
28
+ ssm_kernel_threshold: 64
29
+ ssm_dt_min: 1e-3
30
+ ssm_dt_max: 1e-1
31
+ ssm_dt_init: 1e-2
32
+ ssm_use_d: true
33
+ ssm_activation: gelu
34
+ ssm_gate: true
35
+ ssm_input_gate: true
36
+ ssm_use_padding_mask: false
37
+ ssm_layer_scale_init: 0.1
38
+
39
+ dataset:
40
+ local: true
41
+ jsonl_path: /home/student/Data/TaoData/output.jsonl
42
+ text_field: text
43
+ max_samples: 1000000
44
+ samples_per_chunk: 1000
45
+
46
+ tokenizer_type: sentencepiece
47
+ tokenizer_path: tokenizer/tokenizer.model
48
+ tokenizer_threads: 4
49
+
50
+ batch_size: 32
51
+ num_epochs: 2
52
+ gradient_accumulation_steps: 8
53
+ max_grad_norm: 1.0
54
+
55
+ optimizer:
56
+ optimizer_type: hybrid_muon_adamw
57
+ learning_rate: 5e-3
58
+ adamw_lr: 5e-4
59
+ weight_decay: 0.01
60
+ betas: [0.9, 0.999]
61
+ eps: 1e-8
62
+
63
+ scheduler:
64
+ scheduler_type: cosineWarmup
65
+ warmup_steps: 300
66
+ warmup_ratio: 0.0
67
+ steady_ratio: 0.05
68
+ min_lr_ratio: 0.1
69
+ num_cycles: 0.5
70
+
71
+ dtype: bfloat16
72
+ device: cuda
73
+
74
+ checkpoint_dir: checkpoints/ssm_test
75
+ save_every_steps: 81920
76
+ save_best_model: true
77
+ keep_last_n_checkpoints: 3
78
+
79
+ eval_every_steps: 8192
80
+ eval_samples: 8000
81
+
82
+ log_every_steps: 50
83
+ aim_repo: .aim
84
+
85
+ seed: 42
86
+ num_workers: 0
87
+ pin_memory: true
code/TaoTrain/configs/tokenizer.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example configuration for training a SentencePiece tokenizer from JSONL data
2
+
3
+ # Dataset source - JSONL file
4
+ jsonl_path: /home/student/Data/TaoData/output.jsonl
5
+ text_field: text # Field name in JSON for text data
6
+
7
+ # Tokenizer training parameters
8
+ vocab_size: 8192
9
+ model_type: unigram # SentencePiece model type: unigram, bpe, char, word
10
+ character_coverage: 0.9995
11
+
12
+ # Output configuration
13
+ output_dir: tokenizer
14
+ tokenizer_prefix: tokenizer
15
+
16
+ # Token ID configuration
17
+ unk_id: 0 # Unknown token ID
18
+ bos_id: 1 # Beginning of sentence token ID
19
+ eos_id: 2 # End of sentence token ID
20
+ pad_id: 3 # Padding token ID
21
+
22
+ # Custom special tokens
23
+ # These will be added to the vocabulary with explicit IDs
24
+ # Useful for control tokens like <think>, <user>, <assistant>, etc.
25
+ # Note: Use \n for newline token, \t for tab, etc.
26
+ special_tokens:
27
+ <PAD>: 3 # Padding (typically same as pad_id above)
28
+ <EOS>: 2 # End of sentence (typically same as eos_id above)
29
+ <BOS>: 1 # Beginning of sentence (typically same as bos_id above)
30
+ <UNK>: 0 # Unknown (typically same as unk_id above)
31
+ "\n": 4 # Newline token - quoted to preserve literal \n in YAML
32
+ <think>: 8 # Special token for chain-of-thought reasoning
33
+ <user>: 9 # User message token
34
+ <assistant>: 10 # Assistant message token
35
+ <image>: 11 # Image token for multimodal models
36
+
37
+ # Data sampling (optional)
38
+ # Set to a number to train on only the first N samples from the JSONL file
39
+ # Useful for quick testing or sub-sampling large datasets
40
+ # Omit or set to null to use entire file
41
+ max_samples: 1000000
42
+
43
+ # Optional metadata
44
+ tokenizer_name: tokenizer
code/TaoTrain/configs/tokenizer_taodata_pilot.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pilot SentencePiece tokenizer for remote TaoData benchmarks.
2
+ #
3
+ # This is intentionally smaller than full production tokenizer training so the
4
+ # model-comparison loop can validate attention vs SSM on realistic tokenization.
5
+
6
+ jsonl_path: /home/student/Data/TaoData/pretrain.jsonl.fineweb.jsonl
7
+ text_field: text
8
+
9
+ vocab_size: 8192
10
+ model_type: unigram
11
+ character_coverage: 0.9995
12
+
13
+ output_dir: /home/student/YouZheng/tokenizers/taodata_pilot_8k
14
+ tokenizer_prefix: tokenizer
15
+
16
+ unk_id: 0
17
+ bos_id: 1
18
+ eos_id: 2
19
+ pad_id: 3
20
+
21
+ max_samples: 20000
22
+ tokenizer_name: taodata_pilot_8k
code/TaoTrain/configs/yarn_pretrain.yaml ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TaoNet T2 Configuration for YaRN Continued Pretraining
2
+ # Extended Context: 1024 → 8192 tokens with frequency interpolation
3
+ # Built on DeepSeek MLA + RoPE with Hybrid Muon+AdamW Optimizer
4
+ # Full BF16 precision (no quantization)
5
+
6
+ # ============================================================================
7
+ # Model Architecture - TaoNet (DeepSeek MLA + RoPE) with YaRN Extension
8
+ # ============================================================================
9
+ model:
10
+ architecture_type: taonet
11
+ vocab_size: 8192
12
+ hidden_dim: 512
13
+ num_layers: 12
14
+ num_heads: 8
15
+ max_seq_length: 8192 # Extended from 1024 → 8192 (8x longer context)
16
+
17
+ # TaoNet-specific: Multi-head Latent Attention (MLA) compression
18
+ # KV dimension reduced from 512 to 384 (25% memory savings)
19
+ d_latent_kv: 384
20
+
21
+ # RoPE (Rotary Position Embedding) dimension per head
22
+ # Default would be 512 / 8 = 64
23
+ d_rope: 64
24
+
25
+ # Feed-forward intermediate dimension
26
+ # Default would be 4 * 512 = 2048
27
+ hidden_dim_ff: 1024
28
+
29
+ # Dropout rate (low for stability)
30
+ dropout: 0.02
31
+
32
+ # Grouped Query Attention (1 = standard MLA, >1 = GQA)
33
+ gqa_groups: 1
34
+
35
+ # Optional: Use factorized embedding for parameter efficiency
36
+ use_factorized_embedding: false
37
+ d_embed_rank: 96
38
+
39
+ # Weight initialization standard deviation
40
+ init_std: 0.02
41
+
42
+ # ========================================================================
43
+ # YaRN (Yet another RoPE eXtension) Configuration
44
+ # ========================================================================
45
+ # Enables frequency interpolation to extend context length from 1024 → 8192
46
+ # The model learns to "pack" RoPE frequencies into the new longer context during training.
47
+
48
+ # RoPE base scale factor (explicit, previously hardcoded to 40)
49
+ rope_scale: 40.0
50
+
51
+ # Enable YaRN frequency interpolation
52
+ yarn_enabled: true
53
+
54
+ # Interpolation smoothness parameter
55
+ # - 1.0 (default): Smooth, gradual interpolation—safer for learning extended context
56
+ # - 0.5: Aggressive interpolation—faster context expansion, higher risk
57
+ # - 2.0: Conservative interpolation—safer but slower adaptation
58
+ # Recommendation: Start with 1.0; tune in follow-up runs if convergence issues
59
+ yarn_alpha: 1.0
60
+
61
+ # ============================================================================
62
+ # Dataset Configuration - Local JSONL (Same as Pretrain)
63
+ # ============================================================================
64
+ dataset:
65
+ local: true
66
+ jsonl_path: /home/student/Data/TaoData/output.jsonl
67
+ text_field: text
68
+ max_samples: 50000 # Reduced from 1M → 50k for quick YaRN adaptation
69
+ samples_per_chunk: 1000
70
+
71
+ # Tokenizer configuration (unchanged)
72
+ tokenizer_type: sentencepiece
73
+ tokenizer_path: tokenizer/tokenizer.model
74
+ tokenizer_threads: 4
75
+
76
+ # ============================================================================
77
+ # Training Hyperparameters - Conservative for Context Extension
78
+ # ============================================================================
79
+ # Strategy: Lower learning rates + smaller batch to prevent catastrophic forgetting
80
+ # while the model learns to use 8x longer context.
81
+
82
+ batch_size: 16 # Reduced from 32 (8192 tokens/seq is memory-intensive)
83
+ num_epochs: 1 # 50k samples / effective_batch=256 ≈ 200 updates (1 epoch sufficient for warm-start)
84
+
85
+ # Gradient accumulation to maintain effective batch size of ~256
86
+ # Effective batch = batch_size × gradient_accumulation_steps = 16 × 16 = 256
87
+ gradient_accumulation_steps: 16
88
+
89
+ # Maximum gradient norm for clipping
90
+ max_grad_norm: 1.0
91
+
92
+ # ============================================================================
93
+ # Optimizer - Hybrid Muon + AdamW (Conservative LR for Stability)
94
+ # ============================================================================
95
+ # Strategy: Use 1/2 of pretrain learning rates to:
96
+ # 1. Avoid catastrophic forgetting of learned features
97
+ # 2. Allow smooth adaptation to YaRN-scaled RoPE frequencies
98
+ # 3. Give the model time to learn how to use extended context
99
+
100
+ optimizer:
101
+ optimizer_type: hybrid_muon_adamw
102
+
103
+ # Learning rate for Muon (2D weight matrices)
104
+ # Reduced: 5e-3 → 2.5e-3 (50% of pretrain)
105
+ learning_rate: 2.5e-3
106
+
107
+ # Learning rate for AdamW (1D parameters)
108
+ # Reduced: 5e-4 → 1.25e-4 (25% of pretrain, maintains 1/10 ratio)
109
+ adamw_lr: 1.25e-4
110
+
111
+ # L2 regularization (weight decay)
112
+ weight_decay: 0.01
113
+
114
+ # Adam betas (unchanged)
115
+ betas: [0.9, 0.999]
116
+
117
+ # Epsilon for numerical stability
118
+ eps: 1e-8
119
+
120
+ # ============================================================================
121
+ # Learning Rate Scheduler - 3-Phase Cosine with Warmup (Same as Pretrain)
122
+ # ============================================================================
123
+ # Phases:
124
+ # 1. Warmup: 0 → 1.0 (300 steps, ~1.4% of training)
125
+ # 2. Steady: 1.0 (constant for 5% of training steps at peak LR)
126
+ # 3. Decay: 1.0 → 0.1 (cosine decay for remaining ~95%)
127
+
128
+ scheduler:
129
+ scheduler_type: cosineWarmup
130
+ warmup_steps: 300
131
+ warmup_ratio: 0.0 # Ignored if warmup_steps > 0
132
+ steady_ratio: 0.05 # 5% of total training steps at peak LR
133
+ min_lr_ratio: 0.1 # Decay to 10% of peak LR
134
+ num_cycles: 0.5 # For compatibility (not used in 3-phase schedule)
135
+
136
+ # ============================================================================
137
+ # Data Type and Device
138
+ # ============================================================================
139
+ dtype: bfloat16 # Use BF16 for better convergence with extended context
140
+ device: cuda # Use GPU for training
141
+
142
+ # ============================================================================
143
+ # Checkpointing and Validation
144
+ # ============================================================================
145
+ # Load pretrained checkpoint and continue training
146
+ checkpoint_path: checkpoints/pretrain_taonet/best_model.pt
147
+ checkpoint_dir: checkpoints/yarn_taonet
148
+ save_every_steps: 512 # More frequent saves for 50k samples (200 updates total)
149
+ save_best_model: true
150
+ keep_last_n_checkpoints: 3
151
+
152
+ # Validation every 512 steps (10% of 50k samples)
153
+ eval_every_steps: 512
154
+ eval_samples: 2500 # Reduced from 8000
155
+
156
+ # ============================================================================
157
+ # Logging
158
+ # ============================================================================
159
+ log_every_steps: 50 # Log every 50 updates
160
+ aim_repo: .aim
161
+
162
+ # ============================================================================
163
+ # Miscellaneous
164
+ # ============================================================================
165
+ seed: 42
166
+ num_workers: 0
167
+ pin_memory: true
168
+
169
+ # ============================================================================
170
+ # YaRN Performance Notes
171
+ # ============================================================================
172
+ # Expected memory usage: ~1.5x of pretrain (8x longer seq, half batch)
173
+ # Expected training time: ~50-100 steps/min on H100 (depends on setup)
174
+ # Expected convergence: Loss should decrease over 50k samples; monitor perplexity on 8192-length sequences
175
+ #
176
+ # Tuning recommendations for iterative improvements:
177
+ # 1. If loss is unstable: Reduce learning_rate further (1.25e-3)
178
+ # 2. If loss plateaus quickly: Increase max_samples (100k-200k)
179
+ # 3. If memory OOM: Reduce batch_size to 8 (maintain grad_accum at 16)
180
+ # 4. To speed context expansion: Reduce yarn_alpha to 0.5 (more aggressive)
181
+ # 5. For safer training: Increase yarn_alpha to 2.0 (more conservative)
code/TaoTrain/data/sample_data.jsonl ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {"text": "The quick brown fox jumps over the lazy dog."}
2
+ {"text": "Python is a powerful programming language used for data science, machine learning, and web development."}
3
+ {"text": "Artificial intelligence and machine learning are transforming industries and creating new opportunities."}
4
+ {"text": "Natural language processing enables computers to understand and generate human language."}
5
+ {"text": "Deep learning models like transformers have revolutionized the field of artificial intelligence."}
6
+ {"text": "Transfer learning allows us to leverage pre-trained models to solve new tasks more efficiently."}
7
+ {"text": "The transformer architecture introduced attention mechanisms that became fundamental to modern NLP."}
8
+ {"text": "Language models trained on large corpora can perform impressive few-shot learning tasks."}
9
+ {"text": "Tokenization is a crucial preprocessing step in natural language processing pipelines."}
10
+ {"text": "SentencePiece is a language-independent tokenization algorithm that handles subword segmentation."}
code/TaoTrain/data/sample_sft_multi.jsonl ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {"turns": [{"user": "What is machine learning?", "assistant": "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It uses algorithms to identify patterns in data."}, {"user": "Can you give me some examples?", "assistant": "Sure! Examples include: 1) Email spam detection - learns to identify spam patterns, 2) Image recognition - learns to identify objects in images, 3) Recommendation systems - learns user preferences to suggest content, 4) Natural language processing - learns language patterns for tasks like translation."}]}
2
+ {"turns": [{"user": "Explain neural networks", "assistant": "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes (neurons) organized in layers. Each connection has a weight that gets adjusted during training to learn patterns from data."}, {"user": "What layers do they have?", "assistant": "Neural networks typically have three types of layers: 1) Input layer - receives raw data, 2) Hidden layers - process data through weighted connections and activation functions, 3) Output layer - produces the final prediction or classification. Deep networks have multiple hidden layers."}]}
3
+ {"turns": [{"user": "What is a transformer model?", "assistant": "A transformer is a neural network architecture that uses self-attention mechanisms to weigh the importance of different input elements. It processes all input tokens in parallel (unlike RNNs) and captures long-range dependencies effectively. Transformers form the basis of modern large language models like GPT and BERT."}, {"user": "How does attention work?", "assistant": "Attention computes relevance scores between query tokens and all other tokens (keys and values). It learns which tokens to focus on when processing each position. Multi-head attention applies this in parallel with multiple learned representations. This allows the model to attend to different types of relationships simultaneously."}, {"user": "Why is it better than RNNs?", "assistant": "Transformers are better than RNNs because: 1) Parallel processing - all tokens processed simultaneously instead of sequentially, 2) Better long-range dependencies - attention can directly connect distant tokens, 3) Easier to train - no vanishing gradient problems, 4) Scalability - transformers scale better to large datasets and models."}]}
4
+ {"turns": [{"user": "What is the difference between supervised and unsupervised learning?", "assistant": "Supervised learning uses labeled data (input-output pairs) to train models, like predicting house prices from features. Unsupervised learning finds patterns in unlabeled data, like clustering similar customers. Supervised learning is more accurate but requires labeled data; unsupervised is useful for exploration."}, {"user": "What about semi-supervised learning?", "assistant": "Semi-supervised learning uses both labeled and unlabeled data. It's useful when labeling is expensive. The unlabeled data helps improve performance by providing additional patterns. Common techniques include pseudo-labeling (labeling unlabeled data with the current model) and consistency regularization."}]}
code/TaoTrain/data/sample_sft_single.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"input": "What is Python?", "output": "Python is a high-level, interpreted programming language known for its simplicity and readability. It supports multiple programming paradigms including procedural, object-oriented, and functional programming."}
2
+ {"input": "Explain recursion", "output": "Recursion is a programming technique where a function calls itself to solve a problem by breaking it down into smaller subproblems. Base cases prevent infinite recursion. Example: factorial(n) = n * factorial(n-1) for n > 1, else 1."}
3
+ {"input": "What are decorators in Python?", "output": "Decorators are functions that modify or wrap other functions or classes without permanently changing their source code. They use the @ symbol and are executed when the function is defined. Example: @staticmethod, @property, or custom decorators."}
4
+ {"input": "How do you handle exceptions in Python?", "output": "Exceptions are handled using try-except-else-finally blocks. The try block contains code that might raise an exception, except catches specific exceptions, else runs if no exception occurred, and finally runs regardless. Custom exceptions can be created by inheriting from Exception."}
5
+ {"input": "What is the difference between lists and tuples?", "output": "Lists are mutable (can be changed after creation) while tuples are immutable (cannot be changed). Lists use square brackets [] and tuples use parentheses (). Both can contain mixed data types. Tuples are faster and can be used as dictionary keys."}
code/TaoTrain/pyproject.toml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "taoTrain"
7
+ version = "0.1.0"
8
+ description = "Clean, modular PyTorch LLM training framework with pluggable architectures, AimStack logging, and TUI inference"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+ authors = [
13
+ { name = "Felix", email = "felix@example.com" }
14
+ ]
15
+
16
+ dependencies = [
17
+ "torch>=2.0.0",
18
+ "transformers>=4.30.0",
19
+ "datasets>=2.10.0",
20
+ "pydantic>=2.0.0",
21
+ "pydantic-settings>=2.0.0",
22
+ "aim>=3.15.0",
23
+ "click>=8.1.0",
24
+ "rich>=13.0.0",
25
+ "textual>=0.30.0",
26
+ "numpy>=1.24.0",
27
+ "tqdm>=4.65.0",
28
+ "sentencepiece>=0.1.99",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=7.4.0",
34
+ "pytest-cov>=4.1.0",
35
+ "pytest-xdist>=3.3.0",
36
+ "black>=23.7.0",
37
+ "ruff>=0.0.280",
38
+ "typing-extensions>=4.7.0",
39
+ ]
40
+
41
+ [project.scripts]
42
+ train = "taoTrain.cli:main"
43
+ train-tokenizer = "taoTrain.cli:train_tokenizer_command"
44
+ tui-chat = "taoTrain.inference.tui:main"
45
+
46
+ [tool.setuptools.packages.find]
47
+ where = ["src"]
48
+
49
+ [tool.setuptools.package-data]
50
+ taoTrain = ["configs/**/*.yaml"]
51
+
52
+ [tool.black]
53
+ line-length = 100
54
+ target-version = ["py310"]
55
+
56
+ [tool.ruff]
57
+ line-length = 100
58
+ target-version = "py310"
59
+ select = ["E", "F", "W", "I", "N", "UP", "RUF"]
60
+ ignore = ["E501"]
61
+
62
+ [tool.pytest.ini_options]
63
+ testpaths = ["tests"]
64
+ python_files = "test_*.py"
65
+ addopts = "--verbose"
code/TaoTrain/scripts/benchmark_taonet_real_tokens.py ADDED
@@ -0,0 +1,1034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Real-text token benchmark for TaoNet attention vs TaoNet-SSM.
2
+
3
+ This script consumes a JSONL or plain-text corpus, tokenizes it into one long
4
+ stream, samples contiguous next-token batches, and compares TaoNet variants
5
+ with identical outer dimensions. It is intentionally lighter than the full
6
+ trainer so it can be used in the RepoBridge iteration loop.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import hashlib
13
+ from contextlib import nullcontext
14
+ from contextlib import redirect_stdout
15
+ import csv
16
+ import io
17
+ import json
18
+ import math
19
+ import os
20
+ from pathlib import Path
21
+ import platform
22
+ import random
23
+ import sys
24
+ import time
25
+ from typing import Any, Iterable
26
+
27
+ import torch
28
+
29
+ REPO_ROOT = Path(__file__).resolve().parents[1]
30
+ SRC_ROOT = REPO_ROOT / "src"
31
+ if str(SRC_ROOT) not in sys.path:
32
+ sys.path.insert(0, str(SRC_ROOT))
33
+
34
+ from taoTrain.config import ModelConfig
35
+ from taoTrain.models import get_model
36
+
37
+
38
+ DTYPES = {
39
+ "float32": torch.float32,
40
+ "fp32": torch.float32,
41
+ "float16": torch.float16,
42
+ "fp16": torch.float16,
43
+ "bfloat16": torch.bfloat16,
44
+ "bf16": torch.bfloat16,
45
+ }
46
+
47
+ BYTE_PAD_ID = 0
48
+ BYTE_EOS_ID = 1
49
+ BYTE_UNK_ID = 2
50
+ BYTE_OFFSET = 3
51
+ BYTE_VOCAB_SIZE = 259
52
+
53
+
54
+ def parse_int_list(value: str) -> list[int]:
55
+ return [int(item.strip()) for item in value.split(",") if item.strip()]
56
+
57
+
58
+ def parse_float_list(value: str) -> list[float]:
59
+ return [float(item.strip()) for item in value.split(",") if item.strip()]
60
+
61
+
62
+ def parse_str_list(value: str) -> list[str]:
63
+ return [item.strip() for item in value.split(",") if item.strip()]
64
+
65
+
66
+ def stable_case_id(parts: dict[str, Any]) -> str:
67
+ payload = json.dumps(parts, sort_keys=True, separators=(",", ":"))
68
+ return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]
69
+
70
+
71
+ def make_case_parts(
72
+ args: argparse.Namespace,
73
+ *,
74
+ architecture: str,
75
+ ssm_hidden_dim: int | None,
76
+ ssm_mixer_dim: int | None,
77
+ ssm_num_lanes: int | None,
78
+ hybrid_pattern: str | None,
79
+ ssm_gate_type: str | None,
80
+ learning_rate: float,
81
+ weight_decay: float,
82
+ batch_size: int,
83
+ ) -> dict[str, Any]:
84
+ uses_ssm = architecture in {"taonet_ssm", "taonet_hybrid"}
85
+ return {
86
+ "architecture": architecture,
87
+ "batch_size": batch_size,
88
+ "seq_len": args.seq_len,
89
+ "hidden_dim": args.hidden_dim,
90
+ "num_layers": args.num_layers,
91
+ "num_heads": args.num_heads,
92
+ "hidden_dim_ff": args.hidden_dim_ff,
93
+ "d_latent_kv": args.d_latent_kv,
94
+ "learning_rate": learning_rate,
95
+ "weight_decay": weight_decay,
96
+ "train_steps": args.train_steps,
97
+ "eval_batches": args.eval_batches,
98
+ "ssm_core": args.ssm_core if uses_ssm else None,
99
+ "ssm_hidden_dim": ssm_hidden_dim if uses_ssm else None,
100
+ "ssm_mixer_dim": ssm_mixer_dim if uses_ssm else None,
101
+ "ssm_num_lanes": ssm_num_lanes if uses_ssm else None,
102
+ "ssm_lane_mode": args.ssm_lane_mode if uses_ssm else None,
103
+ "ssm_split_mix": args.ssm_split_mix if uses_ssm else None,
104
+ "ssm_gate_type": ssm_gate_type if uses_ssm else None,
105
+ "ssm_branch_rms_norm": args.ssm_branch_rms_norm if uses_ssm else None,
106
+ "ssm_branch_clip_value": args.ssm_branch_clip_value if uses_ssm else None,
107
+ "block_residual_rms_norm": args.block_residual_rms_norm if uses_ssm else None,
108
+ "block_residual_rms_target": args.block_residual_rms_target if uses_ssm else None,
109
+ "block_residual_rms_cap": args.block_residual_rms_cap if uses_ssm else None,
110
+ "hybrid_pattern": hybrid_pattern if architecture == "taonet_hybrid" else None,
111
+ "hybrid_ssm_layers": args.hybrid_ssm_layers if architecture == "taonet_hybrid" else None,
112
+ }
113
+
114
+
115
+ def synchronize(device: torch.device) -> None:
116
+ if device.type == "cuda":
117
+ torch.cuda.synchronize(device)
118
+
119
+
120
+ def reset_memory(device: torch.device) -> None:
121
+ if device.type == "cuda":
122
+ torch.cuda.reset_peak_memory_stats(device)
123
+
124
+
125
+ def memory_stats(device: torch.device) -> dict[str, float | None]:
126
+ if device.type != "cuda":
127
+ return {"peak_allocated_mb": None, "peak_reserved_mb": None}
128
+ return {
129
+ "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
130
+ "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
131
+ }
132
+
133
+
134
+ def iter_texts(path: Path, *, text_field: str, max_records: int | None) -> Iterable[str]:
135
+ suffix = path.suffix.lower()
136
+ count = 0
137
+ with path.open("r", encoding="utf-8", errors="replace") as handle:
138
+ if suffix in {".jsonl", ".json"}:
139
+ for line in handle:
140
+ if max_records is not None and count >= max_records:
141
+ break
142
+ line = line.strip()
143
+ if not line:
144
+ continue
145
+ try:
146
+ record = json.loads(line)
147
+ except json.JSONDecodeError:
148
+ continue
149
+ text = record.get(text_field)
150
+ if isinstance(text, str) and text:
151
+ count += 1
152
+ yield text
153
+ else:
154
+ for line in handle:
155
+ if max_records is not None and count >= max_records:
156
+ break
157
+ line = line.rstrip("\n")
158
+ if line:
159
+ count += 1
160
+ yield line
161
+
162
+
163
+ def load_sentencepiece(path: Path):
164
+ import sentencepiece as spm
165
+
166
+ processor = spm.SentencePieceProcessor()
167
+ processor.load(str(path))
168
+ return processor
169
+
170
+
171
+ def encode_text(text: str, *, tokenizer_type: str, sentencepiece_processor=None) -> list[int]:
172
+ if tokenizer_type == "byte":
173
+ return [byte + BYTE_OFFSET for byte in text.encode("utf-8", errors="replace")] + [BYTE_EOS_ID]
174
+ if tokenizer_type == "sentencepiece":
175
+ if sentencepiece_processor is None:
176
+ raise ValueError("sentencepiece tokenizer requested without a processor.")
177
+ ids = list(sentencepiece_processor.encode(text, out_type=int))
178
+ eos_id = sentencepiece_processor.eos_id()
179
+ if eos_id >= 0:
180
+ ids.append(eos_id)
181
+ return ids
182
+ raise ValueError(f"Unsupported tokenizer type '{tokenizer_type}'.")
183
+
184
+
185
+ def load_token_stream(args: argparse.Namespace) -> tuple[torch.Tensor, int]:
186
+ path = Path(args.data_path)
187
+ if not path.exists():
188
+ raise FileNotFoundError(f"Data path not found: {path}")
189
+
190
+ sp = None
191
+ if args.tokenizer_type == "sentencepiece":
192
+ if not args.tokenizer_path:
193
+ raise ValueError("--tokenizer-path is required for sentencepiece tokenization.")
194
+ sp = load_sentencepiece(Path(args.tokenizer_path))
195
+ vocab_size = int(sp.vocab_size())
196
+ else:
197
+ vocab_size = BYTE_VOCAB_SIZE
198
+
199
+ tokens: list[int] = []
200
+ for text in iter_texts(path, text_field=args.text_field, max_records=args.max_records):
201
+ tokens.extend(encode_text(text, tokenizer_type=args.tokenizer_type, sentencepiece_processor=sp))
202
+ if args.max_tokens is not None and len(tokens) >= args.max_tokens:
203
+ tokens = tokens[: args.max_tokens]
204
+ break
205
+
206
+ if len(tokens) < args.seq_len + 2:
207
+ raise ValueError(
208
+ f"Need at least {args.seq_len + 2} tokens, found {len(tokens)} from {path}."
209
+ )
210
+ return torch.tensor(tokens, dtype=torch.long), vocab_size
211
+
212
+
213
+ def split_stream(tokens: torch.Tensor, eval_fraction: float) -> tuple[torch.Tensor, torch.Tensor]:
214
+ split = int(tokens.numel() * (1.0 - eval_fraction))
215
+ split = max(2, min(split, tokens.numel() - 2))
216
+ return tokens[:split].contiguous(), tokens[split:].contiguous()
217
+
218
+
219
+ def sample_batch(
220
+ stream: torch.Tensor,
221
+ *,
222
+ batch_size: int,
223
+ seq_len: int,
224
+ device: torch.device,
225
+ generator: torch.Generator,
226
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
227
+ max_start = stream.numel() - seq_len - 1
228
+ if max_start <= 0:
229
+ raise ValueError(f"Token stream too short for seq_len={seq_len}.")
230
+ starts = torch.randint(0, max_start, (batch_size,), generator=generator)
231
+ rows = [stream[int(start) : int(start) + seq_len + 1] for start in starts]
232
+ batch = torch.stack(rows, dim=0).to(device=device)
233
+ input_ids = batch[:, :-1].contiguous()
234
+ labels = batch[:, 1:].contiguous()
235
+ attention_mask = torch.ones_like(input_ids)
236
+ return input_ids, labels, attention_mask
237
+
238
+
239
+ def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
240
+ predictions = torch.argmax(logits, dim=-1)
241
+ correct = predictions == labels
242
+ return float(correct.sum().detach().cpu() / correct.numel())
243
+
244
+
245
+ def build_config(
246
+ args: argparse.Namespace,
247
+ *,
248
+ architecture: str,
249
+ vocab_size: int,
250
+ ssm_hidden_dim: int | None,
251
+ ssm_mixer_dim: int | None,
252
+ ssm_num_lanes: int | None,
253
+ hybrid_pattern: str | None,
254
+ ssm_gate_type: str | None,
255
+ ) -> ModelConfig:
256
+ uses_ssm = architecture in {"taonet_ssm", "taonet_hybrid"}
257
+ d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75)
258
+ d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads
259
+ hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4
260
+ return ModelConfig(
261
+ architecture_type=architecture,
262
+ vocab_size=vocab_size,
263
+ hidden_dim=args.hidden_dim,
264
+ num_layers=args.num_layers,
265
+ num_heads=args.num_heads,
266
+ max_seq_length=args.seq_len,
267
+ d_latent_kv=d_latent_kv,
268
+ d_rope=d_rope,
269
+ hidden_dim_ff=hidden_dim_ff,
270
+ dropout=args.dropout,
271
+ gqa_groups=args.gqa_groups,
272
+ rope_scale=args.rope_scale,
273
+ yarn_alpha=args.yarn_alpha,
274
+ init_std=args.init_std,
275
+ ssm_core=args.ssm_core,
276
+ ssm_hidden_dim=ssm_hidden_dim,
277
+ ssm_mixer_dim=ssm_mixer_dim,
278
+ ssm_num_lanes=ssm_num_lanes or args.ssm_num_lanes,
279
+ ssm_lane_combine=args.ssm_lane_combine,
280
+ ssm_lane_mode=args.ssm_lane_mode,
281
+ ssm_split_mix=args.ssm_split_mix,
282
+ ssm_rank=args.ssm_rank,
283
+ ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
284
+ ssm_finite_tail_correction=args.ssm_finite_tail_correction,
285
+ ssm_kernel_mode=args.ssm_kernel_mode,
286
+ ssm_kernel_threshold=args.ssm_kernel_threshold,
287
+ ssm_dt_min=args.ssm_dt_min,
288
+ ssm_dt_max=args.ssm_dt_max,
289
+ ssm_dt_init=args.ssm_dt_init,
290
+ ssm_use_padding_mask=False,
291
+ ssm_activation=args.ssm_activation,
292
+ ssm_gate=args.ssm_gate,
293
+ ssm_input_gate=args.ssm_input_gate,
294
+ ssm_gate_type=ssm_gate_type or args.ssm_gate_type,
295
+ ssm_layer_scale_init=args.ssm_layer_scale_init,
296
+ ssm_branch_rms_norm=args.ssm_branch_rms_norm if uses_ssm else False,
297
+ ssm_branch_clip_value=args.ssm_branch_clip_value if uses_ssm else None,
298
+ block_residual_rms_norm=args.block_residual_rms_norm if uses_ssm else False,
299
+ block_residual_rms_target=args.block_residual_rms_target,
300
+ block_residual_rms_cap=args.block_residual_rms_cap if uses_ssm else None,
301
+ ssm_local_shift=args.ssm_local_shift,
302
+ ssm_local_shift_init=args.ssm_local_shift_init,
303
+ ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
304
+ hybrid_pattern=hybrid_pattern or args.hybrid_pattern,
305
+ hybrid_ssm_layers=args.hybrid_ssm_layers or None,
306
+ )
307
+
308
+
309
+ def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]:
310
+ last_loss = float("nan")
311
+ for _ in range(warmup):
312
+ last_loss = fn()
313
+ synchronize(device)
314
+
315
+ latencies = []
316
+ for _ in range(repeats):
317
+ reset_memory(device)
318
+ synchronize(device)
319
+ start = time.perf_counter()
320
+ last_loss = fn()
321
+ synchronize(device)
322
+ latencies.append(time.perf_counter() - start)
323
+ return sum(latencies) / len(latencies), min(latencies), last_loss
324
+
325
+
326
+ def evaluate_model(
327
+ model: torch.nn.Module,
328
+ *,
329
+ eval_stream: torch.Tensor,
330
+ args: argparse.Namespace,
331
+ batch_size: int,
332
+ device: torch.device,
333
+ generator: torch.Generator,
334
+ autocast_context,
335
+ ) -> tuple[float, float]:
336
+ model.eval()
337
+ losses = []
338
+ accuracies = []
339
+ with torch.no_grad():
340
+ for _ in range(args.eval_batches):
341
+ input_ids, labels, attention_mask = sample_batch(
342
+ eval_stream,
343
+ batch_size=batch_size,
344
+ seq_len=args.seq_len,
345
+ device=device,
346
+ generator=generator,
347
+ )
348
+ with autocast_context():
349
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
350
+ losses.append(float(outputs["loss"].detach().cpu()))
351
+ accuracies.append(token_accuracy(outputs["logits"], labels))
352
+ model.train()
353
+ return sum(losses) / len(losses), sum(accuracies) / len(accuracies)
354
+
355
+
356
+ def train_model(
357
+ model: torch.nn.Module,
358
+ *,
359
+ train_stream: torch.Tensor,
360
+ args: argparse.Namespace,
361
+ learning_rate: float,
362
+ weight_decay: float,
363
+ batch_size: int,
364
+ device: torch.device,
365
+ generator: torch.Generator,
366
+ autocast_context,
367
+ ) -> tuple[float | None, float | None, float | None, float | None, list[dict[str, float | int]]]:
368
+ if args.train_steps <= 0:
369
+ return None, None, None, None, []
370
+
371
+ model.train()
372
+ optimizer = torch.optim.AdamW(
373
+ model.parameters(),
374
+ lr=learning_rate,
375
+ weight_decay=weight_decay,
376
+ )
377
+ last_loss = float("nan")
378
+ last_grad_norm: float | None = None
379
+ max_grad_norm_seen: float | None = None
380
+ history: list[dict[str, float | int]] = []
381
+ start = time.perf_counter()
382
+ trainable_params = [param for param in model.parameters() if param.requires_grad]
383
+ log_every = max(0, int(args.train_log_every))
384
+ for step in range(1, args.train_steps + 1):
385
+ input_ids, labels, attention_mask = sample_batch(
386
+ train_stream,
387
+ batch_size=batch_size,
388
+ seq_len=args.seq_len,
389
+ device=device,
390
+ generator=generator,
391
+ )
392
+ optimizer.zero_grad(set_to_none=True)
393
+ with autocast_context():
394
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
395
+ loss = outputs["loss"]
396
+ loss.backward()
397
+ if args.max_grad_norm and args.max_grad_norm > 0:
398
+ grad_norm_tensor = torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
399
+ last_grad_norm = float(grad_norm_tensor.detach().cpu())
400
+ max_grad_norm_seen = (
401
+ last_grad_norm
402
+ if max_grad_norm_seen is None
403
+ else max(max_grad_norm_seen, last_grad_norm)
404
+ )
405
+ optimizer.step()
406
+ last_loss = float(loss.detach().cpu())
407
+ if log_every and (step == 1 or step % log_every == 0 or step == args.train_steps):
408
+ history.append(
409
+ {
410
+ "step": step,
411
+ "loss": last_loss,
412
+ "grad_norm": last_grad_norm if last_grad_norm is not None else float("nan"),
413
+ }
414
+ )
415
+ synchronize(device)
416
+ return last_loss, time.perf_counter() - start, last_grad_norm, max_grad_norm_seen, history
417
+
418
+
419
+ def save_case_checkpoint(
420
+ *,
421
+ args: argparse.Namespace,
422
+ model: torch.nn.Module,
423
+ config: ModelConfig,
424
+ case_id: str,
425
+ case_parts: dict[str, Any],
426
+ train_final_loss: float | None,
427
+ train_seconds: float | None,
428
+ train_history: list[dict[str, float | int]],
429
+ eval_loss: float,
430
+ eval_accuracy: float,
431
+ ) -> str | None:
432
+ if not args.save_case_checkpoints:
433
+ return None
434
+ checkpoint_dir = Path(args.checkpoint_dir) if args.checkpoint_dir else Path(args.output_dir) / "checkpoints"
435
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
436
+ path = checkpoint_dir / f"{case_id}.pt"
437
+ checkpoint = {
438
+ "step": args.train_steps,
439
+ "model_state": model.state_dict(),
440
+ "optimizer_state": None,
441
+ "config": {"model": config.to_dict(), "benchmark_args": vars(args)},
442
+ "metrics": {
443
+ "train_final_loss": train_final_loss,
444
+ "train_seconds": train_seconds,
445
+ "train_history": train_history,
446
+ "eval_loss": eval_loss,
447
+ "eval_accuracy": eval_accuracy,
448
+ },
449
+ "case_id": case_id,
450
+ "case_parts": case_parts,
451
+ }
452
+ torch.save(checkpoint, path)
453
+ latest_path = checkpoint_dir / "latest.pt"
454
+ try:
455
+ torch.save(checkpoint, latest_path)
456
+ except Exception:
457
+ pass
458
+ return str(path)
459
+
460
+
461
+ def benchmark_case(
462
+ *,
463
+ args: argparse.Namespace,
464
+ architecture: str,
465
+ ssm_hidden_dim: int | None,
466
+ ssm_mixer_dim: int | None,
467
+ ssm_num_lanes: int | None,
468
+ hybrid_pattern: str | None,
469
+ ssm_gate_type: str | None,
470
+ learning_rate: float,
471
+ weight_decay: float,
472
+ vocab_size: int,
473
+ train_stream: torch.Tensor,
474
+ eval_stream: torch.Tensor,
475
+ batch_size: int,
476
+ dtype: torch.dtype,
477
+ device: torch.device,
478
+ ) -> list[dict[str, Any]]:
479
+ seed_offset = 0 if architecture == "taonet" else int(ssm_hidden_dim or 0) + int(ssm_mixer_dim or 0)
480
+ seed_offset += int(ssm_num_lanes or 0) * 31
481
+ if hybrid_pattern:
482
+ seed_offset += sum(ord(char) for char in hybrid_pattern)
483
+ if ssm_gate_type:
484
+ seed_offset += sum(ord(char) for char in ssm_gate_type)
485
+ case_seed = args.seed + seed_offset + batch_size
486
+ torch.manual_seed(case_seed)
487
+ if device.type == "cuda":
488
+ torch.cuda.manual_seed_all(case_seed)
489
+ cpu_generator = torch.Generator().manual_seed(case_seed)
490
+ case_parts = make_case_parts(
491
+ args,
492
+ architecture=architecture,
493
+ ssm_hidden_dim=ssm_hidden_dim,
494
+ ssm_mixer_dim=ssm_mixer_dim,
495
+ ssm_num_lanes=ssm_num_lanes,
496
+ hybrid_pattern=hybrid_pattern,
497
+ ssm_gate_type=ssm_gate_type,
498
+ learning_rate=learning_rate,
499
+ weight_decay=weight_decay,
500
+ batch_size=batch_size,
501
+ )
502
+ case_id = stable_case_id(case_parts)
503
+ config = build_config(
504
+ args,
505
+ architecture=architecture,
506
+ vocab_size=vocab_size,
507
+ ssm_hidden_dim=ssm_hidden_dim,
508
+ ssm_mixer_dim=ssm_mixer_dim,
509
+ ssm_num_lanes=ssm_num_lanes,
510
+ hybrid_pattern=hybrid_pattern,
511
+ ssm_gate_type=ssm_gate_type,
512
+ )
513
+ with redirect_stdout(io.StringIO()):
514
+ model = get_model(config, device=device)
515
+ model.train()
516
+
517
+ if device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}:
518
+ def autocast_context():
519
+ return torch.autocast(device_type="cuda", dtype=dtype, enabled=True)
520
+ else:
521
+ def autocast_context():
522
+ return nullcontext()
523
+
524
+ (
525
+ train_final_loss,
526
+ train_seconds,
527
+ train_last_grad_norm,
528
+ train_max_grad_norm,
529
+ train_history,
530
+ ) = train_model(
531
+ model,
532
+ train_stream=train_stream,
533
+ args=args,
534
+ learning_rate=learning_rate,
535
+ weight_decay=weight_decay,
536
+ batch_size=batch_size,
537
+ device=device,
538
+ generator=cpu_generator,
539
+ autocast_context=autocast_context,
540
+ )
541
+ eval_loss, eval_accuracy = evaluate_model(
542
+ model,
543
+ eval_stream=eval_stream,
544
+ args=args,
545
+ batch_size=batch_size,
546
+ device=device,
547
+ generator=cpu_generator,
548
+ autocast_context=autocast_context,
549
+ )
550
+ checkpoint_path = save_case_checkpoint(
551
+ args=args,
552
+ model=model,
553
+ config=config,
554
+ case_id=case_id,
555
+ case_parts=case_parts,
556
+ train_final_loss=train_final_loss,
557
+ train_seconds=train_seconds,
558
+ train_history=train_history,
559
+ eval_loss=eval_loss,
560
+ eval_accuracy=eval_accuracy,
561
+ )
562
+
563
+ input_ids, labels, attention_mask = sample_batch(
564
+ eval_stream,
565
+ batch_size=batch_size,
566
+ seq_len=args.seq_len,
567
+ device=device,
568
+ generator=cpu_generator,
569
+ )
570
+ tokens = batch_size * args.seq_len
571
+ total_params = sum(param.numel() for param in model.parameters())
572
+ trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
573
+ rows: list[dict[str, Any]] = []
574
+ uses_ssm = architecture in {"taonet_ssm", "taonet_hybrid"}
575
+
576
+ def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None:
577
+ perplexity = math.exp(min(eval_loss, 20.0))
578
+ rows.append(
579
+ {
580
+ "architecture": architecture,
581
+ "case_id": case_id,
582
+ "checkpoint_path": checkpoint_path,
583
+ "ssm_core": args.ssm_core if uses_ssm else None,
584
+ "ssm_hidden_dim": ssm_hidden_dim if uses_ssm else None,
585
+ "ssm_mixer_dim": ssm_mixer_dim if uses_ssm else None,
586
+ "ssm_num_lanes": ssm_num_lanes if uses_ssm else None,
587
+ "ssm_lane_combine": args.ssm_lane_combine if uses_ssm else None,
588
+ "ssm_lane_mode": args.ssm_lane_mode if uses_ssm else None,
589
+ "ssm_split_mix": args.ssm_split_mix if uses_ssm else None,
590
+ "ssm_finite_tail_correction": args.ssm_finite_tail_correction if uses_ssm else None,
591
+ "ssm_gate_type": ssm_gate_type if uses_ssm else None,
592
+ "ssm_local_shift": args.ssm_local_shift if uses_ssm else None,
593
+ "ssm_local_shift_per_channel": args.ssm_local_shift_per_channel if uses_ssm else None,
594
+ "ssm_branch_rms_norm": args.ssm_branch_rms_norm if uses_ssm else None,
595
+ "ssm_branch_clip_value": args.ssm_branch_clip_value if uses_ssm else None,
596
+ "block_residual_rms_norm": args.block_residual_rms_norm if uses_ssm else None,
597
+ "block_residual_rms_target": args.block_residual_rms_target if uses_ssm else None,
598
+ "block_residual_rms_cap": args.block_residual_rms_cap if uses_ssm else None,
599
+ "hybrid_pattern": hybrid_pattern if architecture == "taonet_hybrid" else None,
600
+ "hybrid_ssm_layers": args.hybrid_ssm_layers if architecture == "taonet_hybrid" else None,
601
+ "mode": mode,
602
+ "batch_size": batch_size,
603
+ "seq_len": args.seq_len,
604
+ "tokens": tokens,
605
+ "vocab_size": vocab_size,
606
+ "tokenizer_type": args.tokenizer_type,
607
+ "hidden_dim": args.hidden_dim,
608
+ "num_layers": args.num_layers,
609
+ "num_heads": args.num_heads,
610
+ "dtype": str(dtype).replace("torch.", ""),
611
+ "device": str(device),
612
+ "learning_rate": learning_rate,
613
+ "weight_decay": weight_decay,
614
+ "total_params": total_params,
615
+ "trainable_params": trainable_params,
616
+ "mean_ms": mean_s * 1000.0,
617
+ "min_ms": min_s * 1000.0,
618
+ "tokens_per_s_mean": tokens / max(mean_s, 1e-12),
619
+ "tokens_per_s_best": tokens / max(min_s, 1e-12),
620
+ "loss": loss,
621
+ "eval_loss": eval_loss,
622
+ "eval_perplexity": perplexity,
623
+ "eval_accuracy": eval_accuracy,
624
+ "train_final_loss": train_final_loss,
625
+ "train_seconds": train_seconds,
626
+ "train_history": json.dumps(train_history, separators=(",", ":")),
627
+ "train_last_grad_norm": train_last_grad_norm,
628
+ "train_max_grad_norm": train_max_grad_norm,
629
+ "max_grad_norm": args.max_grad_norm,
630
+ **memory_stats(device),
631
+ }
632
+ )
633
+
634
+ def forward_only() -> float:
635
+ with torch.no_grad():
636
+ with autocast_context():
637
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
638
+ return float(outputs["loss"].detach().cpu())
639
+
640
+ mean_s, min_s, loss = time_repeats(
641
+ forward_only,
642
+ device=device,
643
+ warmup=args.warmup,
644
+ repeats=args.repeats,
645
+ )
646
+ add_row("forward", mean_s, min_s, loss)
647
+
648
+ if args.backward:
649
+ def forward_backward() -> float:
650
+ model.zero_grad(set_to_none=True)
651
+ with autocast_context():
652
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
653
+ loss = outputs["loss"]
654
+ loss.backward()
655
+ return float(loss.detach().cpu())
656
+
657
+ mean_s, min_s, loss = time_repeats(
658
+ forward_backward,
659
+ device=device,
660
+ warmup=args.warmup,
661
+ repeats=args.repeats,
662
+ )
663
+ add_row("forward_backward", mean_s, min_s, loss)
664
+
665
+ return rows
666
+
667
+
668
+ def print_table(rows: list[dict[str, Any]]) -> None:
669
+ columns = [
670
+ "architecture",
671
+ "hybrid_pattern",
672
+ "ssm_hidden_dim",
673
+ "ssm_mixer_dim",
674
+ "ssm_num_lanes",
675
+ "ssm_lane_combine",
676
+ "ssm_lane_mode",
677
+ "ssm_split_mix",
678
+ "ssm_gate_type",
679
+ "ssm_branch_rms_norm",
680
+ "block_residual_rms_norm",
681
+ "block_residual_rms_cap",
682
+ "learning_rate",
683
+ "mode",
684
+ "batch_size",
685
+ "seq_len",
686
+ "mean_ms",
687
+ "tokens_per_s_mean",
688
+ "peak_allocated_mb",
689
+ "eval_loss",
690
+ "eval_perplexity",
691
+ "eval_accuracy",
692
+ "train_max_grad_norm",
693
+ ]
694
+ print("\t".join(columns))
695
+ for row in rows:
696
+ values = []
697
+ for column in columns:
698
+ value = row[column]
699
+ if isinstance(value, float):
700
+ values.append(f"{value:.3f}")
701
+ else:
702
+ values.append(str(value))
703
+ print("\t".join(values))
704
+
705
+
706
+ def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
707
+ output_dir.mkdir(parents=True, exist_ok=True)
708
+ json_path = output_dir / "taonet_real_token_benchmark.json"
709
+ csv_path = output_dir / "taonet_real_token_benchmark.csv"
710
+ json_path.write_text(json.dumps({"metadata": metadata, "results": rows}, indent=2), encoding="utf-8")
711
+
712
+ fieldnames = list(rows[0].keys()) if rows else []
713
+ with csv_path.open("w", newline="", encoding="utf-8") as handle:
714
+ writer = csv.DictWriter(handle, fieldnames=fieldnames)
715
+ writer.writeheader()
716
+ writer.writerows(rows)
717
+
718
+ print(f"Wrote {json_path}")
719
+ print(f"Wrote {csv_path}")
720
+
721
+
722
+ def load_existing_rows(output_dir: Path) -> list[dict[str, Any]]:
723
+ csv_path = output_dir / "taonet_real_token_benchmark.csv"
724
+ if not csv_path.exists():
725
+ return []
726
+ with csv_path.open("r", newline="", encoding="utf-8") as handle:
727
+ return list(csv.DictReader(handle))
728
+
729
+
730
+ def main() -> None:
731
+ parser = argparse.ArgumentParser(description="Benchmark TaoNet variants on a real text token stream.")
732
+ parser.add_argument("--data-path", required=True)
733
+ parser.add_argument("--text-field", default="text")
734
+ parser.add_argument("--tokenizer-type", choices=["byte", "sentencepiece"], default="byte")
735
+ parser.add_argument("--tokenizer-path", default="")
736
+ parser.add_argument("--max-records", type=int, default=None)
737
+ parser.add_argument("--max-tokens", type=int, default=1_000_000)
738
+ parser.add_argument("--eval-fraction", type=float, default=0.1)
739
+ parser.add_argument("--architectures", default="taonet,taonet_ssm")
740
+ parser.add_argument("--batch-sizes", default="8,16")
741
+ parser.add_argument("--seq-len", type=int, default=512)
742
+ parser.add_argument("--hidden-dim", type=int, default=256)
743
+ parser.add_argument("--num-layers", type=int, default=4)
744
+ parser.add_argument("--num-heads", type=int, default=4)
745
+ parser.add_argument("--d-latent-kv", type=int, default=None)
746
+ parser.add_argument("--d-rope", type=int, default=None)
747
+ parser.add_argument("--hidden-dim-ff", type=int, default=None)
748
+ parser.add_argument("--dropout", type=float, default=0.0)
749
+ parser.add_argument("--gqa-groups", type=int, default=1)
750
+ parser.add_argument("--rope-scale", type=float, default=40.0)
751
+ parser.add_argument("--yarn-alpha", type=float, default=1.0)
752
+ parser.add_argument("--init-std", type=float, default=0.02)
753
+ parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
754
+ parser.add_argument("--ssm-hidden-dims", default="16,64")
755
+ parser.add_argument("--ssm-mixer-dim", type=int, default=64)
756
+ parser.add_argument(
757
+ "--ssm-mixer-dims",
758
+ default="",
759
+ help="Comma-separated SSM mixer projection dimensions to sweep. Defaults to --ssm-mixer-dim.",
760
+ )
761
+ parser.add_argument("--ssm-num-lanes", type=int, default=1)
762
+ parser.add_argument(
763
+ "--ssm-num-lanes-list",
764
+ default="",
765
+ help="Comma-separated SSM lane counts to sweep for SSM-bearing architectures.",
766
+ )
767
+ parser.add_argument("--ssm-lane-combine", choices=["mean", "channel"], default="mean")
768
+ parser.add_argument("--ssm-lane-mode", choices=["full", "split"], default="full")
769
+ parser.add_argument(
770
+ "--ssm-lane-modes",
771
+ default="",
772
+ help="Comma-separated SSM lane modes to sweep for SSM-bearing architectures.",
773
+ )
774
+ parser.add_argument("--ssm-split-mix", choices=["none", "hadamard"], default="none")
775
+ parser.add_argument(
776
+ "--ssm-split-mixes",
777
+ default="",
778
+ help="Comma-separated split-lane cross mixers to sweep for SSM-bearing architectures.",
779
+ )
780
+ parser.add_argument("--ssm-rank", type=int, default=1)
781
+ parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
782
+ parser.add_argument("--ssm-finite-tail-correction", action=argparse.BooleanOptionalAction, default=True)
783
+ parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
784
+ parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
785
+ parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
786
+ parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
787
+ parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
788
+ parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
789
+ parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
790
+ parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
791
+ parser.add_argument("--ssm-gate-type", choices=["dense", "channel"], default="dense")
792
+ parser.add_argument(
793
+ "--ssm-gate-types",
794
+ default="",
795
+ help="Comma-separated SSM gate implementations to sweep for SSM-bearing architectures.",
796
+ )
797
+ parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
798
+ parser.add_argument("--ssm-branch-rms-norm", action=argparse.BooleanOptionalAction, default=False)
799
+ parser.add_argument("--ssm-branch-clip-value", type=float, default=None)
800
+ parser.add_argument("--block-residual-rms-norm", action=argparse.BooleanOptionalAction, default=False)
801
+ parser.add_argument("--block-residual-rms-target", type=float, default=1.0)
802
+ parser.add_argument("--block-residual-rms-cap", type=float, default=None)
803
+ parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True)
804
+ parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
805
+ parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=False)
806
+ parser.add_argument(
807
+ "--hybrid-pattern",
808
+ choices=["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"],
809
+ default="attention_first",
810
+ )
811
+ parser.add_argument(
812
+ "--hybrid-patterns",
813
+ default="",
814
+ help="Comma-separated taonet_hybrid patterns to sweep. Ignored when --hybrid-ssm-layers is set.",
815
+ )
816
+ parser.add_argument(
817
+ "--hybrid-ssm-layers",
818
+ default="",
819
+ help="Comma-separated 0-based SSM layer indices for taonet_hybrid. Overrides pattern selection.",
820
+ )
821
+ parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
822
+ parser.add_argument("--device", default="auto")
823
+ parser.add_argument("--warmup", type=int, default=1)
824
+ parser.add_argument("--repeats", type=int, default=3)
825
+ parser.add_argument("--backward", action="store_true")
826
+ parser.add_argument("--train-steps", type=int, default=100)
827
+ parser.add_argument(
828
+ "--train-log-every",
829
+ type=int,
830
+ default=0,
831
+ help="Record train loss/grad history every N steps in CSV/JSON/checkpoints. 0 disables.",
832
+ )
833
+ parser.add_argument("--learning-rate", type=float, default=8e-4)
834
+ parser.add_argument(
835
+ "--max-grad-norm",
836
+ type=float,
837
+ default=0.0,
838
+ help="Clip training gradients to this norm during the benchmark. Use <=0 to disable.",
839
+ )
840
+ parser.add_argument(
841
+ "--learning-rates",
842
+ default="",
843
+ help="Comma-separated SSM learning rates to sweep. Attention uses --learning-rate once.",
844
+ )
845
+ parser.add_argument("--weight-decay", type=float, default=0.01)
846
+ parser.add_argument(
847
+ "--weight-decays",
848
+ default="",
849
+ help="Comma-separated SSM weight decays to sweep. Attention uses --weight-decay once.",
850
+ )
851
+ parser.add_argument("--eval-batches", type=int, default=8)
852
+ parser.add_argument("--seed", type=int, default=42)
853
+ parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/real-token-bench"))
854
+ parser.add_argument(
855
+ "--resume-completed",
856
+ action="store_true",
857
+ help="Skip benchmark cases whose case_id is already present in the output CSV.",
858
+ )
859
+ parser.add_argument(
860
+ "--incremental-output",
861
+ action=argparse.BooleanOptionalAction,
862
+ default=True,
863
+ help="Rewrite CSV/JSON outputs after each completed case so detached jobs leave usable partial results.",
864
+ )
865
+ parser.add_argument(
866
+ "--save-case-checkpoints",
867
+ action="store_true",
868
+ help="Save a model checkpoint after each trained benchmark case.",
869
+ )
870
+ parser.add_argument(
871
+ "--checkpoint-dir",
872
+ default="",
873
+ help="Directory for per-case checkpoints. Defaults to <output-dir>/checkpoints.",
874
+ )
875
+ args = parser.parse_args()
876
+
877
+ random.seed(args.seed)
878
+ torch.manual_seed(args.seed)
879
+ if args.device == "auto":
880
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
881
+ else:
882
+ device = torch.device(args.device)
883
+ dtype = DTYPES[args.dtype]
884
+ if device.type == "cuda":
885
+ torch.backends.cuda.matmul.allow_tf32 = True
886
+ torch.backends.cudnn.allow_tf32 = True
887
+
888
+ tokens, vocab_size = load_token_stream(args)
889
+ train_stream, eval_stream = split_stream(tokens, args.eval_fraction)
890
+ architectures = [item.strip() for item in args.architectures.split(",") if item.strip()]
891
+ ssm_hidden_dims = parse_int_list(args.ssm_hidden_dims)
892
+ ssm_mixer_dims = parse_int_list(args.ssm_mixer_dims) if args.ssm_mixer_dims else [args.ssm_mixer_dim]
893
+ ssm_num_lanes_options = parse_int_list(args.ssm_num_lanes_list) if args.ssm_num_lanes_list else [args.ssm_num_lanes]
894
+ ssm_lane_mode_options = parse_str_list(args.ssm_lane_modes) if args.ssm_lane_modes else [args.ssm_lane_mode]
895
+ ssm_split_mix_options = parse_str_list(args.ssm_split_mixes) if args.ssm_split_mixes else [args.ssm_split_mix]
896
+ ssm_learning_rates = parse_float_list(args.learning_rates) if args.learning_rates else [args.learning_rate]
897
+ ssm_weight_decays = parse_float_list(args.weight_decays) if args.weight_decays else [args.weight_decay]
898
+ hybrid_patterns = parse_str_list(args.hybrid_patterns) if args.hybrid_patterns else [args.hybrid_pattern]
899
+ ssm_gate_types = parse_str_list(args.ssm_gate_types) if args.ssm_gate_types else [args.ssm_gate_type]
900
+
901
+ output_dir = Path(args.output_dir)
902
+ rows: list[dict[str, Any]] = load_existing_rows(output_dir) if args.resume_completed else []
903
+ completed_case_ids = {
904
+ str(row.get("case_id"))
905
+ for row in rows
906
+ if row.get("mode") == "forward_backward" or not args.backward
907
+ }
908
+ for architecture in architectures:
909
+ hidden_options: list[int | None] = [None]
910
+ mixer_options: list[int | None] = [None]
911
+ learning_rate_options = [args.learning_rate]
912
+ weight_decay_options = [args.weight_decay]
913
+ hybrid_pattern_options: list[str | None] = [None]
914
+ gate_type_options: list[str | None] = [None]
915
+ lane_options: list[int | None] = [None]
916
+ lane_mode_options: list[str | None] = [None]
917
+ split_mix_options: list[str | None] = ["none"]
918
+ if architecture in {"taonet_ssm", "taonet_hybrid"}:
919
+ hidden_options = ssm_hidden_dims
920
+ mixer_options = ssm_mixer_dims
921
+ learning_rate_options = ssm_learning_rates
922
+ weight_decay_options = ssm_weight_decays
923
+ gate_type_options = ssm_gate_types
924
+ lane_options = ssm_num_lanes_options
925
+ lane_mode_options = ssm_lane_mode_options
926
+ split_mix_options = ssm_split_mix_options
927
+ if architecture == "taonet_hybrid":
928
+ hybrid_pattern_options = [args.hybrid_pattern] if args.hybrid_ssm_layers else hybrid_patterns
929
+ for ssm_hidden_dim in hidden_options:
930
+ for ssm_mixer_dim in mixer_options:
931
+ for ssm_num_lanes in lane_options:
932
+ for ssm_lane_mode in lane_mode_options:
933
+ for ssm_split_mix in split_mix_options:
934
+ if (
935
+ architecture in {"taonet_ssm", "taonet_hybrid"}
936
+ and ssm_lane_mode == "split"
937
+ and (ssm_num_lanes is None or ssm_num_lanes <= 1)
938
+ ):
939
+ continue
940
+ if (
941
+ architecture in {"taonet_ssm", "taonet_hybrid"}
942
+ and ssm_lane_mode == "split"
943
+ and ssm_mixer_dim is not None
944
+ and ssm_num_lanes is not None
945
+ and ssm_mixer_dim % ssm_num_lanes != 0
946
+ ):
947
+ print(
948
+ "Skipping split-lane case because ssm_mixer_dim is not divisible by "
949
+ f"ssm_num_lanes: mixer={ssm_mixer_dim} lanes={ssm_num_lanes}"
950
+ )
951
+ continue
952
+ if (
953
+ architecture in {"taonet_ssm", "taonet_hybrid"}
954
+ and ssm_lane_mode != "split"
955
+ and ssm_split_mix != "none"
956
+ ):
957
+ continue
958
+ if (
959
+ architecture in {"taonet_ssm", "taonet_hybrid"}
960
+ and ssm_split_mix == "hadamard"
961
+ and ssm_num_lanes != 2
962
+ ):
963
+ continue
964
+ args.ssm_lane_mode = ssm_lane_mode or args.ssm_lane_mode
965
+ args.ssm_split_mix = ssm_split_mix or args.ssm_split_mix
966
+ for hybrid_pattern in hybrid_pattern_options:
967
+ for ssm_gate_type in gate_type_options:
968
+ for learning_rate in learning_rate_options:
969
+ for weight_decay in weight_decay_options:
970
+ for batch_size in parse_int_list(args.batch_sizes):
971
+ print(
972
+ "Benchmarking "
973
+ f"architecture={architecture} ssm_hidden_dim={ssm_hidden_dim} "
974
+ f"ssm_mixer_dim={ssm_mixer_dim} ssm_num_lanes={ssm_num_lanes} "
975
+ f"ssm_lane_mode={args.ssm_lane_mode} ssm_split_mix={args.ssm_split_mix} "
976
+ f"hybrid_pattern={hybrid_pattern} ssm_gate_type={ssm_gate_type} "
977
+ f"lr={learning_rate} wd={weight_decay} batch={batch_size} seq={args.seq_len}"
978
+ )
979
+ case_parts = make_case_parts(
980
+ args,
981
+ architecture=architecture,
982
+ ssm_hidden_dim=ssm_hidden_dim,
983
+ ssm_mixer_dim=ssm_mixer_dim,
984
+ ssm_num_lanes=ssm_num_lanes,
985
+ hybrid_pattern=hybrid_pattern,
986
+ ssm_gate_type=ssm_gate_type,
987
+ learning_rate=learning_rate,
988
+ weight_decay=weight_decay,
989
+ batch_size=batch_size,
990
+ )
991
+ case_id = stable_case_id(case_parts)
992
+ if args.resume_completed and case_id in completed_case_ids:
993
+ print(f"Skipping completed case_id={case_id}")
994
+ continue
995
+ case_rows = benchmark_case(
996
+ args=args,
997
+ architecture=architecture,
998
+ ssm_hidden_dim=ssm_hidden_dim,
999
+ ssm_mixer_dim=ssm_mixer_dim,
1000
+ ssm_num_lanes=ssm_num_lanes,
1001
+ hybrid_pattern=hybrid_pattern,
1002
+ ssm_gate_type=ssm_gate_type,
1003
+ learning_rate=learning_rate,
1004
+ weight_decay=weight_decay,
1005
+ vocab_size=vocab_size,
1006
+ train_stream=train_stream,
1007
+ eval_stream=eval_stream,
1008
+ batch_size=batch_size,
1009
+ dtype=dtype,
1010
+ device=device,
1011
+ )
1012
+ rows.extend(case_rows)
1013
+ completed_case_ids.add(case_id)
1014
+ if args.incremental_output:
1015
+ write_outputs(rows, output_dir, {})
1016
+
1017
+ metadata = {
1018
+ "python": platform.python_version(),
1019
+ "platform": platform.platform(),
1020
+ "torch": torch.__version__,
1021
+ "cuda_available": torch.cuda.is_available(),
1022
+ "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
1023
+ "args": vars(args),
1024
+ "vocab_size": vocab_size,
1025
+ "num_tokens": int(tokens.numel()),
1026
+ "train_tokens": int(train_stream.numel()),
1027
+ "eval_tokens": int(eval_stream.numel()),
1028
+ }
1029
+ print_table(rows)
1030
+ write_outputs(rows, output_dir, metadata)
1031
+
1032
+
1033
+ if __name__ == "__main__":
1034
+ main()
code/TaoTrain/scripts/benchmark_taonet_token_variants.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token-level benchmark for TaoNet attention vs TaoNet-SSM.
2
+
3
+ The goal is to compare the two LLM wrappers with the same outer dimensions:
4
+ original MLA attention TaoNet versus TaoNet with an SSM mixer.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ from contextlib import nullcontext
11
+ from contextlib import redirect_stdout
12
+ import csv
13
+ import io
14
+ import json
15
+ import os
16
+ from pathlib import Path
17
+ import platform
18
+ import subprocess
19
+ import sys
20
+ import time
21
+ from typing import Any
22
+
23
+ import torch
24
+
25
+ REPO_ROOT = Path(__file__).resolve().parents[1]
26
+ SRC_ROOT = REPO_ROOT / "src"
27
+ if str(SRC_ROOT) not in sys.path:
28
+ sys.path.insert(0, str(SRC_ROOT))
29
+
30
+ from taoTrain.config import ModelConfig
31
+ from taoTrain.models import get_model
32
+
33
+
34
+ DTYPES = {
35
+ "float32": torch.float32,
36
+ "fp32": torch.float32,
37
+ "float16": torch.float16,
38
+ "fp16": torch.float16,
39
+ "bfloat16": torch.bfloat16,
40
+ "bf16": torch.bfloat16,
41
+ }
42
+
43
+
44
+ def parse_int_list(value: str) -> list[int]:
45
+ return [int(item.strip()) for item in value.split(",") if item.strip()]
46
+
47
+
48
+ def synchronize(device: torch.device) -> None:
49
+ if device.type == "cuda":
50
+ torch.cuda.synchronize(device)
51
+
52
+
53
+ def reset_memory(device: torch.device) -> None:
54
+ if device.type == "cuda":
55
+ torch.cuda.reset_peak_memory_stats(device)
56
+
57
+
58
+ def memory_stats(device: torch.device) -> dict[str, float | None]:
59
+ if device.type != "cuda":
60
+ return {
61
+ "peak_allocated_mb": None,
62
+ "peak_reserved_mb": None,
63
+ }
64
+ return {
65
+ "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
66
+ "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
67
+ }
68
+
69
+
70
+ def nvidia_smi_snapshot() -> str | None:
71
+ try:
72
+ completed = subprocess.run(
73
+ [
74
+ "nvidia-smi",
75
+ "--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu",
76
+ "--format=csv,noheader,nounits",
77
+ ],
78
+ check=False,
79
+ capture_output=True,
80
+ text=True,
81
+ timeout=5,
82
+ )
83
+ except (OSError, subprocess.TimeoutExpired):
84
+ return None
85
+ if completed.returncode != 0:
86
+ return None
87
+ return completed.stdout.strip()
88
+
89
+
90
+ def make_token_batch(
91
+ *,
92
+ batch_size: int,
93
+ seq_len: int,
94
+ vocab_size: int,
95
+ device: torch.device,
96
+ task: str = "random",
97
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ if task == "random":
99
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
100
+ labels = torch.empty_like(input_ids)
101
+ labels[:, :-1] = input_ids[:, 1:]
102
+ labels[:, -1] = torch.randint(0, vocab_size, (batch_size,), device=device)
103
+ elif task == "increment":
104
+ starts = torch.randint(0, vocab_size, (batch_size, 1), device=device)
105
+ offsets = torch.arange(seq_len, device=device).view(1, seq_len)
106
+ input_ids = (starts + offsets) % vocab_size
107
+ labels = (input_ids + 1) % vocab_size
108
+ elif task == "previous":
109
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
110
+ labels = torch.empty_like(input_ids)
111
+ labels[:, 0] = -100
112
+ labels[:, 1:] = input_ids[:, :-1]
113
+ else:
114
+ raise ValueError(f"Unsupported token task '{task}'.")
115
+ attention_mask = torch.ones_like(input_ids)
116
+ return input_ids, labels, attention_mask
117
+
118
+
119
+ def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
120
+ valid = labels != -100
121
+ if not torch.any(valid):
122
+ return float("nan")
123
+ predictions = torch.argmax(logits, dim=-1)
124
+ correct = (predictions == labels) & valid
125
+ return float(correct.sum().detach().cpu() / valid.sum().detach().cpu())
126
+
127
+
128
+ def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
129
+ d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75)
130
+ d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads
131
+ hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4
132
+ return ModelConfig(
133
+ architecture_type=architecture,
134
+ vocab_size=args.vocab_size,
135
+ hidden_dim=args.hidden_dim,
136
+ num_layers=args.num_layers,
137
+ num_heads=args.num_heads,
138
+ max_seq_length=max(parse_int_list(args.seq_lens)),
139
+ d_latent_kv=d_latent_kv,
140
+ d_rope=d_rope,
141
+ hidden_dim_ff=hidden_dim_ff,
142
+ dropout=args.dropout,
143
+ gqa_groups=args.gqa_groups,
144
+ rope_scale=args.rope_scale,
145
+ yarn_alpha=args.yarn_alpha,
146
+ init_std=args.init_std,
147
+ ssm_core=args.ssm_core,
148
+ ssm_hidden_dim=args.ssm_hidden_dim or d_latent_kv,
149
+ ssm_mixer_dim=args.ssm_mixer_dim,
150
+ ssm_rank=args.ssm_rank,
151
+ ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
152
+ ssm_kernel_mode=args.ssm_kernel_mode,
153
+ ssm_kernel_threshold=args.ssm_kernel_threshold,
154
+ ssm_dt_min=args.ssm_dt_min,
155
+ ssm_dt_max=args.ssm_dt_max,
156
+ ssm_dt_init=args.ssm_dt_init,
157
+ ssm_use_padding_mask=args.ssm_use_padding_mask,
158
+ ssm_activation=args.ssm_activation,
159
+ ssm_gate=args.ssm_gate,
160
+ ssm_input_gate=args.ssm_input_gate,
161
+ ssm_layer_scale_init=args.ssm_layer_scale_init,
162
+ ssm_local_shift=args.ssm_local_shift,
163
+ ssm_local_shift_init=args.ssm_local_shift_init,
164
+ ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
165
+ )
166
+
167
+
168
+ def count_params(model: torch.nn.Module) -> tuple[int, int]:
169
+ total = sum(param.numel() for param in model.parameters())
170
+ trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
171
+ return total, trainable
172
+
173
+
174
+ def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]:
175
+ last_loss = float("nan")
176
+ for _ in range(warmup):
177
+ last_loss = fn()
178
+ synchronize(device)
179
+
180
+ latencies = []
181
+ for _ in range(repeats):
182
+ reset_memory(device)
183
+ synchronize(device)
184
+ start = time.perf_counter()
185
+ last_loss = fn()
186
+ synchronize(device)
187
+ latencies.append(time.perf_counter() - start)
188
+ return sum(latencies) / len(latencies), min(latencies), last_loss
189
+
190
+
191
+ def evaluate_model(
192
+ model: torch.nn.Module,
193
+ *,
194
+ args: argparse.Namespace,
195
+ batch_size: int,
196
+ seq_len: int,
197
+ device: torch.device,
198
+ autocast_context,
199
+ ) -> tuple[float, float]:
200
+ model.eval()
201
+ losses = []
202
+ accuracies = []
203
+ with torch.no_grad():
204
+ for _ in range(args.eval_batches):
205
+ input_ids, labels, attention_mask = make_token_batch(
206
+ batch_size=batch_size,
207
+ seq_len=seq_len,
208
+ vocab_size=args.vocab_size,
209
+ device=device,
210
+ task=args.token_task,
211
+ )
212
+ with autocast_context():
213
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
214
+ losses.append(float(outputs["loss"].detach().cpu()))
215
+ accuracies.append(token_accuracy(outputs["logits"], labels))
216
+ model.train()
217
+ return sum(losses) / len(losses), sum(accuracies) / len(accuracies)
218
+
219
+
220
+ def train_model(
221
+ model: torch.nn.Module,
222
+ *,
223
+ args: argparse.Namespace,
224
+ batch_size: int,
225
+ seq_len: int,
226
+ device: torch.device,
227
+ autocast_context,
228
+ ) -> tuple[float | None, float | None]:
229
+ if args.train_steps <= 0:
230
+ return None, None
231
+
232
+ model.train()
233
+ optimizer = torch.optim.AdamW(
234
+ model.parameters(),
235
+ lr=args.learning_rate,
236
+ weight_decay=args.weight_decay,
237
+ )
238
+ last_loss = float("nan")
239
+ start = time.perf_counter()
240
+ for _ in range(args.train_steps):
241
+ input_ids, labels, attention_mask = make_token_batch(
242
+ batch_size=batch_size,
243
+ seq_len=seq_len,
244
+ vocab_size=args.vocab_size,
245
+ device=device,
246
+ task=args.token_task,
247
+ )
248
+ optimizer.zero_grad(set_to_none=True)
249
+ with autocast_context():
250
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
251
+ loss = outputs["loss"]
252
+ loss.backward()
253
+ optimizer.step()
254
+ last_loss = float(loss.detach().cpu())
255
+ synchronize(device)
256
+ return last_loss, time.perf_counter() - start
257
+
258
+
259
+ def benchmark_case(
260
+ *,
261
+ args: argparse.Namespace,
262
+ architecture: str,
263
+ batch_size: int,
264
+ seq_len: int,
265
+ dtype: torch.dtype,
266
+ device: torch.device,
267
+ ) -> list[dict[str, Any]]:
268
+ config = build_config(args, architecture)
269
+ with redirect_stdout(io.StringIO()):
270
+ model = get_model(config, device=device)
271
+ model.train()
272
+ total_params, trainable_params = count_params(model)
273
+ tokens = batch_size * seq_len
274
+ input_ids, labels, attention_mask = make_token_batch(
275
+ batch_size=batch_size,
276
+ seq_len=seq_len,
277
+ vocab_size=args.vocab_size,
278
+ device=device,
279
+ task=args.token_task,
280
+ )
281
+ autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
282
+
283
+ def autocast_context():
284
+ if not autocast_enabled:
285
+ return nullcontext()
286
+ return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)
287
+
288
+ train_final_loss, train_seconds = train_model(
289
+ model,
290
+ args=args,
291
+ batch_size=batch_size,
292
+ seq_len=seq_len,
293
+ device=device,
294
+ autocast_context=autocast_context,
295
+ )
296
+ eval_loss, eval_accuracy = evaluate_model(
297
+ model,
298
+ args=args,
299
+ batch_size=batch_size,
300
+ seq_len=seq_len,
301
+ device=device,
302
+ autocast_context=autocast_context,
303
+ )
304
+
305
+ rows: list[dict[str, Any]] = []
306
+
307
+ def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None:
308
+ rows.append(
309
+ {
310
+ "architecture": architecture,
311
+ "ssm_core": args.ssm_core if architecture == "taonet_ssm" else None,
312
+ "token_task": args.token_task,
313
+ "train_steps": args.train_steps,
314
+ "mode": mode,
315
+ "batch_size": batch_size,
316
+ "seq_len": seq_len,
317
+ "tokens": tokens,
318
+ "vocab_size": args.vocab_size,
319
+ "hidden_dim": args.hidden_dim,
320
+ "num_layers": args.num_layers,
321
+ "num_heads": args.num_heads,
322
+ "d_latent_kv": config.d_latent_kv,
323
+ "ssm_hidden_dim": config.ssm_hidden_dim if architecture == "taonet_ssm" else None,
324
+ "ssm_mixer_dim": config.ssm_mixer_dim if architecture == "taonet_ssm" else None,
325
+ "ssm_rank": config.ssm_rank if architecture == "taonet_ssm" else None,
326
+ "ssm_local_shift": config.ssm_local_shift if architecture == "taonet_ssm" else None,
327
+ "ssm_local_shift_init": config.ssm_local_shift_init if architecture == "taonet_ssm" else None,
328
+ "ssm_local_shift_per_channel": config.ssm_local_shift_per_channel if architecture == "taonet_ssm" else None,
329
+ "dtype": str(dtype).replace("torch.", ""),
330
+ "device": str(device),
331
+ "total_params": total_params,
332
+ "trainable_params": trainable_params,
333
+ "mean_ms": mean_s * 1000.0,
334
+ "min_ms": min_s * 1000.0,
335
+ "tokens_per_s_mean": tokens / max(mean_s, 1e-12),
336
+ "tokens_per_s_best": tokens / max(min_s, 1e-12),
337
+ "loss": loss,
338
+ "eval_loss": eval_loss,
339
+ "eval_accuracy": eval_accuracy,
340
+ "train_final_loss": train_final_loss,
341
+ "train_seconds": train_seconds,
342
+ **memory_stats(device),
343
+ }
344
+ )
345
+
346
+ def forward_only() -> float:
347
+ with torch.no_grad():
348
+ with autocast_context():
349
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
350
+ loss = outputs["loss"]
351
+ return float(loss.detach().cpu())
352
+
353
+ mean_s, min_s, loss = time_repeats(
354
+ forward_only,
355
+ device=device,
356
+ warmup=args.warmup,
357
+ repeats=args.repeats,
358
+ )
359
+ add_row("forward", mean_s, min_s, loss)
360
+
361
+ if args.backward:
362
+ def forward_backward() -> float:
363
+ model.zero_grad(set_to_none=True)
364
+ with autocast_context():
365
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
366
+ loss = outputs["loss"]
367
+ loss.backward()
368
+ return float(loss.detach().cpu())
369
+
370
+ mean_s, min_s, loss = time_repeats(
371
+ forward_backward,
372
+ device=device,
373
+ warmup=args.warmup,
374
+ repeats=args.repeats,
375
+ )
376
+ add_row("forward_backward", mean_s, min_s, loss)
377
+
378
+ return rows
379
+
380
+
381
+ def print_table(rows: list[dict[str, Any]]) -> None:
382
+ columns = [
383
+ "architecture",
384
+ "ssm_core",
385
+ "token_task",
386
+ "mode",
387
+ "batch_size",
388
+ "seq_len",
389
+ "mean_ms",
390
+ "tokens_per_s_mean",
391
+ "peak_allocated_mb",
392
+ "loss",
393
+ "eval_loss",
394
+ "eval_accuracy",
395
+ ]
396
+ print("\t".join(columns))
397
+ for row in rows:
398
+ values = []
399
+ for column in columns:
400
+ value = row[column]
401
+ if isinstance(value, float):
402
+ values.append(f"{value:.3f}")
403
+ else:
404
+ values.append(str(value))
405
+ print("\t".join(values))
406
+
407
+
408
+ def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
409
+ output_dir.mkdir(parents=True, exist_ok=True)
410
+ json_path = output_dir / "taonet_token_benchmark.json"
411
+ csv_path = output_dir / "taonet_token_benchmark.csv"
412
+ json_path.write_text(json.dumps({"metadata": metadata, "results": rows}, indent=2), encoding="utf-8")
413
+
414
+ fieldnames = list(rows[0].keys()) if rows else []
415
+ with csv_path.open("w", newline="", encoding="utf-8") as handle:
416
+ writer = csv.DictWriter(handle, fieldnames=fieldnames)
417
+ writer.writeheader()
418
+ writer.writerows(rows)
419
+
420
+ print(f"Wrote {json_path}")
421
+ print(f"Wrote {csv_path}")
422
+
423
+
424
+ def main() -> None:
425
+ parser = argparse.ArgumentParser(description="Benchmark TaoNet attention vs TaoNet-SSM on token batches.")
426
+ parser.add_argument("--architectures", default="taonet,taonet_ssm")
427
+ parser.add_argument("--batch-sizes", default="1,4")
428
+ parser.add_argument("--seq-lens", default="128,512")
429
+ parser.add_argument("--vocab-size", type=int, default=8192)
430
+ parser.add_argument("--hidden-dim", type=int, default=256)
431
+ parser.add_argument("--num-layers", type=int, default=4)
432
+ parser.add_argument("--num-heads", type=int, default=4)
433
+ parser.add_argument("--d-latent-kv", type=int, default=None)
434
+ parser.add_argument("--d-rope", type=int, default=None)
435
+ parser.add_argument("--hidden-dim-ff", type=int, default=None)
436
+ parser.add_argument("--dropout", type=float, default=0.0)
437
+ parser.add_argument("--gqa-groups", type=int, default=1)
438
+ parser.add_argument("--rope-scale", type=float, default=40.0)
439
+ parser.add_argument("--yarn-alpha", type=float, default=1.0)
440
+ parser.add_argument("--init-std", type=float, default=0.02)
441
+ parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
442
+ parser.add_argument("--ssm-hidden-dim", type=int, default=None)
443
+ parser.add_argument("--ssm-mixer-dim", type=int, default=None)
444
+ parser.add_argument("--ssm-rank", type=int, default=1)
445
+ parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
446
+ parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
447
+ parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
448
+ parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
449
+ parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
450
+ parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
451
+ parser.add_argument("--ssm-use-padding-mask", action="store_true")
452
+ parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
453
+ parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
454
+ parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
455
+ parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
456
+ parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=False)
457
+ parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
458
+ parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=False)
459
+ parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
460
+ parser.add_argument("--device", default="auto")
461
+ parser.add_argument("--warmup", type=int, default=2)
462
+ parser.add_argument("--repeats", type=int, default=5)
463
+ parser.add_argument("--backward", action="store_true")
464
+ parser.add_argument("--token-task", choices=["random", "increment", "previous"], default="random")
465
+ parser.add_argument("--train-steps", type=int, default=0)
466
+ parser.add_argument("--learning-rate", type=float, default=3e-4)
467
+ parser.add_argument("--weight-decay", type=float, default=0.01)
468
+ parser.add_argument("--eval-batches", type=int, default=1)
469
+ parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/token-bench"))
470
+ args = parser.parse_args()
471
+
472
+ if args.device == "auto":
473
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
474
+ else:
475
+ device = torch.device(args.device)
476
+ dtype = DTYPES[args.dtype]
477
+ if device.type != "cuda" and dtype == torch.float16:
478
+ raise ValueError("float16 benchmark requires CUDA.")
479
+ if device.type == "cuda":
480
+ torch.backends.cuda.matmul.allow_tf32 = True
481
+ torch.backends.cudnn.allow_tf32 = True
482
+
483
+ architectures = [item.strip() for item in args.architectures.split(",") if item.strip()]
484
+ rows: list[dict[str, Any]] = []
485
+ metadata = {
486
+ "python": platform.python_version(),
487
+ "platform": platform.platform(),
488
+ "torch": torch.__version__,
489
+ "cuda_available": torch.cuda.is_available(),
490
+ "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
491
+ "nvidia_smi_before": nvidia_smi_snapshot(),
492
+ "args": vars(args),
493
+ }
494
+
495
+ for architecture in architectures:
496
+ for batch_size in parse_int_list(args.batch_sizes):
497
+ for seq_len in parse_int_list(args.seq_lens):
498
+ print(f"Benchmarking architecture={architecture} batch={batch_size} seq={seq_len}")
499
+ rows.extend(
500
+ benchmark_case(
501
+ args=args,
502
+ architecture=architecture,
503
+ batch_size=batch_size,
504
+ seq_len=seq_len,
505
+ dtype=dtype,
506
+ device=device,
507
+ )
508
+ )
509
+
510
+ metadata["nvidia_smi_after"] = nvidia_smi_snapshot()
511
+ print_table(rows)
512
+ write_outputs(rows, Path(args.output_dir), metadata)
513
+
514
+
515
+ if __name__ == "__main__":
516
+ main()
code/TaoTrain/scripts/diagnostics/activation_probe.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Probe residual activation scale for a saved TaoTrain checkpoint."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import re
8
+ import sys
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import torch
13
+
14
+ REPO_ROOT = Path(__file__).resolve().parents[2]
15
+ SRC_ROOT = REPO_ROOT / "src"
16
+ if str(SRC_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(SRC_ROOT))
18
+
19
+ from taoTrain.checkpointing.checkpoint import CheckpointManager
20
+ from taoTrain.config import ModelConfig
21
+ from taoTrain.models import get_model
22
+
23
+
24
+ def load_sentencepiece(path: Path):
25
+ import sentencepiece as spm
26
+
27
+ processor = spm.SentencePieceProcessor()
28
+ processor.load(str(path))
29
+ return processor
30
+
31
+
32
+ def load_tokens(args: argparse.Namespace) -> tuple[torch.Tensor, int]:
33
+ tokenizer = load_sentencepiece(Path(args.tokenizer_path))
34
+ tokens: list[int] = []
35
+ with Path(args.data_path).open("r", encoding="utf-8", errors="replace") as handle:
36
+ for line in handle:
37
+ if len(tokens) >= args.max_tokens:
38
+ break
39
+ line = line.strip()
40
+ if not line:
41
+ continue
42
+ try:
43
+ record = json.loads(line)
44
+ except json.JSONDecodeError:
45
+ continue
46
+ text = record.get(args.text_field)
47
+ if not isinstance(text, str) or not text:
48
+ continue
49
+ ids = list(tokenizer.encode(text, out_type=int))
50
+ eos_id = tokenizer.eos_id()
51
+ if eos_id >= 0:
52
+ ids.append(eos_id)
53
+ tokens.extend(ids)
54
+ if len(tokens) < args.seq_len + 2:
55
+ raise ValueError(f"Need at least {args.seq_len + 2} tokens, got {len(tokens)}")
56
+ return torch.tensor(tokens[: args.max_tokens], dtype=torch.long), int(tokenizer.vocab_size())
57
+
58
+
59
+ def sample_batch(tokens: torch.Tensor, *, batch_size: int, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
60
+ max_start = tokens.numel() - seq_len - 1
61
+ starts = torch.linspace(0, max_start - 1, steps=batch_size).long()
62
+ rows = [tokens[int(start) : int(start) + seq_len + 1] for start in starts]
63
+ batch = torch.stack(rows, dim=0).to(device=device)
64
+ return batch[:, :-1].contiguous(), batch[:, 1:].contiguous()
65
+
66
+
67
+ def tensor_stats(value: torch.Tensor) -> dict[str, float | int]:
68
+ data = value.detach().float()
69
+ finite = torch.isfinite(data)
70
+ finite_count = int(finite.sum().cpu())
71
+ numel = data.numel()
72
+ if finite_count:
73
+ finite_data = data[finite]
74
+ rms = float(torch.sqrt(torch.mean(finite_data * finite_data)).cpu())
75
+ max_abs = float(finite_data.abs().max().cpu())
76
+ else:
77
+ rms = float("inf")
78
+ max_abs = float("inf")
79
+ return {
80
+ "numel": numel,
81
+ "finite": finite_count,
82
+ "rms": rms,
83
+ "max_abs": max_abs,
84
+ }
85
+
86
+
87
+ def main() -> None:
88
+ parser = argparse.ArgumentParser()
89
+ parser.add_argument("--checkpoint", required=True)
90
+ parser.add_argument("--tokenizer-path", required=True)
91
+ parser.add_argument("--data-path", required=True)
92
+ parser.add_argument("--text-field", default="text")
93
+ parser.add_argument("--output", required=True)
94
+ parser.add_argument("--batch-size", type=int, default=2)
95
+ parser.add_argument("--seq-len", type=int, default=512)
96
+ parser.add_argument("--max-tokens", type=int, default=200_000)
97
+ parser.add_argument("--device", default="cuda")
98
+ parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
99
+ args = parser.parse_args()
100
+
101
+ device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
102
+ dtype = {
103
+ "float32": torch.float32,
104
+ "bfloat16": torch.bfloat16,
105
+ "float16": torch.float16,
106
+ }[args.dtype]
107
+
108
+ tokens, _ = load_tokens(args)
109
+ input_ids, labels = sample_batch(tokens, batch_size=args.batch_size, seq_len=args.seq_len, device=device)
110
+ attention_mask = torch.ones_like(input_ids)
111
+
112
+ checkpoint_path = Path(args.checkpoint)
113
+ checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
114
+ config_dict = checkpoint.get("config", {})
115
+ model_config = ModelConfig(**config_dict.get("model", {}))
116
+ model = get_model(model_config, device=device)
117
+ model.load_state_dict(checkpoint["model_state"], strict=False)
118
+ model.eval()
119
+
120
+ layer_stats: dict[str, dict[str, float | int]] = {}
121
+ handles = []
122
+ layer_pattern = re.compile(r"^(?:model\.)?(?:layers|blocks)\.\d+$")
123
+
124
+ def make_hook(name: str):
125
+ def hook(_module, _inputs, output):
126
+ value = output[0] if isinstance(output, tuple) else output
127
+ if torch.is_tensor(value):
128
+ layer_stats[name] = tensor_stats(value)
129
+
130
+ return hook
131
+
132
+ for name, module in model.named_modules():
133
+ if layer_pattern.match(name):
134
+ handles.append(module.register_forward_hook(make_hook(name)))
135
+
136
+ device_type = "cuda" if device.type == "cuda" else "cpu"
137
+ autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
138
+ with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
139
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
140
+
141
+ for handle in handles:
142
+ handle.remove()
143
+
144
+ result: dict[str, Any] = {
145
+ "checkpoint": str(checkpoint_path),
146
+ "loss": float(outputs["loss"].detach().cpu()),
147
+ "batch_size": args.batch_size,
148
+ "seq_len": args.seq_len,
149
+ "device": str(device),
150
+ "dtype": str(dtype),
151
+ "layers": layer_stats,
152
+ }
153
+ output_path = Path(args.output)
154
+ output_path.parent.mkdir(parents=True, exist_ok=True)
155
+ output_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
156
+ print(json.dumps(result, indent=2))
157
+
158
+
159
+ if __name__ == "__main__":
160
+ main()
code/TaoTrain/scripts/diagnostics/generate_checkpoint_samples.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate a few text samples from a saved checkpoint."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ import torch
11
+
12
+ REPO_ROOT = Path(__file__).resolve().parents[2]
13
+ SRC_ROOT = REPO_ROOT / "src"
14
+ if str(SRC_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(SRC_ROOT))
16
+
17
+ from taoTrain.checkpointing.checkpoint import CheckpointManager
18
+ from taoTrain.config import ModelConfig
19
+ from taoTrain.inference.inferencer import Inferencer
20
+ from taoTrain.models import get_model
21
+
22
+
23
+ def clear_kernel_caches(model) -> None:
24
+ for module in model.modules():
25
+ clear = getattr(module, "clear_kernel_cache", None)
26
+ if callable(clear):
27
+ clear()
28
+
29
+
30
+ def generate_once(
31
+ model,
32
+ tokenizer,
33
+ prompt: str,
34
+ *,
35
+ device: torch.device,
36
+ max_new_tokens: int,
37
+ temperature: float,
38
+ top_p: float,
39
+ dtype: torch.dtype,
40
+ ) -> str:
41
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
42
+ generated = []
43
+ eos_token_id = getattr(tokenizer, "eos_token_id", None)
44
+ model.eval()
45
+ device_type = "cuda" if device.type == "cuda" else "cpu"
46
+ autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
47
+ with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
48
+ for _ in range(max_new_tokens):
49
+ clear_kernel_caches(model)
50
+ outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None)
51
+ logits = outputs["logits"][:, -1, :] / max(temperature, 1e-6)
52
+ if top_p < 1.0:
53
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
54
+ sorted_probs = torch.softmax(sorted_logits, dim=-1)
55
+ cumulative = torch.cumsum(sorted_probs, dim=-1)
56
+ remove = cumulative > top_p
57
+ remove[..., 1:] = remove[..., :-1].clone()
58
+ remove[..., 0] = False
59
+ indices_to_remove = sorted_indices[remove]
60
+ logits[0, indices_to_remove] = float("-inf")
61
+ probs = torch.softmax(logits, dim=-1)
62
+ next_token = torch.multinomial(probs, num_samples=1)
63
+ token_id = int(next_token.item())
64
+ if eos_token_id is not None and token_id == eos_token_id:
65
+ break
66
+ generated.append(token_id)
67
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
68
+ clear_kernel_caches(model)
69
+ return tokenizer.decode(generated, skip_special_tokens=True)
70
+
71
+
72
+ def main() -> None:
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("--checkpoint", required=True)
75
+ parser.add_argument("--tokenizer-path", required=True)
76
+ parser.add_argument("--output", required=True)
77
+ parser.add_argument("--prompt", action="append", default=[])
78
+ parser.add_argument("--max-new-tokens", type=int, default=80)
79
+ parser.add_argument("--temperature", type=float, default=0.8)
80
+ parser.add_argument("--top-p", type=float, default=0.9)
81
+ parser.add_argument("--device", default="cuda")
82
+ parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
83
+ args = parser.parse_args()
84
+
85
+ prompts = args.prompt or [
86
+ "The purpose of artificial intelligence is",
87
+ "In a small village,",
88
+ "<user>Hello, who are you?<assistant>",
89
+ ]
90
+ device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
91
+ dtype = {
92
+ "float32": torch.float32,
93
+ "bfloat16": torch.bfloat16,
94
+ "float16": torch.float16,
95
+ }[args.dtype]
96
+ tokenizer = Inferencer._load_tokenizer(args.tokenizer_path)
97
+ checkpoint_path = Path(args.checkpoint)
98
+ checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
99
+ model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {}))
100
+ model = get_model(model_config, device=device)
101
+ model.load_state_dict(checkpoint["model_state"], strict=False)
102
+
103
+ samples = []
104
+ for prompt in prompts:
105
+ text = generate_once(
106
+ model,
107
+ tokenizer,
108
+ prompt,
109
+ device=device,
110
+ max_new_tokens=args.max_new_tokens,
111
+ temperature=args.temperature,
112
+ top_p=args.top_p,
113
+ dtype=dtype,
114
+ )
115
+ samples.append({"prompt": prompt, "completion": text})
116
+
117
+ result = {
118
+ "checkpoint": args.checkpoint,
119
+ "tokenizer_path": args.tokenizer_path,
120
+ "device": str(device),
121
+ "dtype": str(dtype),
122
+ "max_new_tokens": args.max_new_tokens,
123
+ "temperature": args.temperature,
124
+ "top_p": args.top_p,
125
+ "samples": samples,
126
+ }
127
+ output = Path(args.output)
128
+ output.parent.mkdir(parents=True, exist_ok=True)
129
+ output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
130
+ print(json.dumps(result, indent=2, ensure_ascii=False))
131
+
132
+
133
+ if __name__ == "__main__":
134
+ main()
code/TaoTrain/scripts/diagnostics/sft_sanity_check.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Small SFT diagnostics for checkpoint quality and trainability.
2
+
3
+ This script intentionally bypasses the full trainer so it can answer one narrow
4
+ question quickly: can the checkpoint reduce response-only SFT loss on a tiny,
5
+ fixed batch?
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import json
12
+ import math
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ import torch
17
+
18
+ from taoTrain.checkpointing.checkpoint import CheckpointManager
19
+ from taoTrain.config import TrainingModeEnum, load_config
20
+ from taoTrain.core import create_model
21
+ from taoTrain.data.sft_utils import build_sft_sequence_tokens, parse_sft_record
22
+ try:
23
+ from taoTrain.data.sft_utils import build_response_only_next_token_labels
24
+ except ImportError:
25
+ def build_response_only_next_token_labels(input_ids: list[int], mask: list[int]) -> list[int]:
26
+ labels = [token_id if mask_value else -100 for token_id, mask_value in zip(input_ids, mask)]
27
+ return labels[1:] + [-100]
28
+ from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper
29
+ from taoTrain.utils import set_seed
30
+
31
+
32
+ def load_tokenizer(tokenizer_path: str):
33
+ path = Path(tokenizer_path)
34
+ if path.suffix == ".model":
35
+ import sentencepiece as spm
36
+
37
+ sp = spm.SentencePieceProcessor()
38
+ sp.Load(str(path))
39
+ return SentencePieceTokenizerWrapper(sp)
40
+
41
+ from transformers import AutoTokenizer
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
44
+ if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None):
45
+ tokenizer.pad_token = tokenizer.eos_token
46
+ return tokenizer
47
+
48
+
49
+ def read_jsonl_records(path: str, limit: int) -> list[dict[str, Any]]:
50
+ records = []
51
+ with open(path, "r", encoding="utf-8") as handle:
52
+ for line in handle:
53
+ line = line.strip()
54
+ if not line:
55
+ continue
56
+ records.append(json.loads(line))
57
+ if len(records) >= limit:
58
+ break
59
+ return records
60
+
61
+
62
+ def build_batch(config, tokenizer, records: list[dict[str, Any]], device: torch.device) -> dict[str, torch.Tensor]:
63
+ input_rows = []
64
+ attention_rows = []
65
+ label_rows = []
66
+ train_tokens = []
67
+
68
+ for record in records:
69
+ turns, _ = parse_sft_record(record, config)
70
+ if not turns:
71
+ continue
72
+ input_ids, attention_mask, mask = build_sft_sequence_tokens(
73
+ turns=turns,
74
+ tokenizer=tokenizer,
75
+ user_token=getattr(config, "user_token", "<user>"),
76
+ assistant_token=getattr(config, "assistant_token", "<assistant>"),
77
+ max_seq_length=config.model.max_seq_length,
78
+ )
79
+ labels = build_response_only_next_token_labels(input_ids, mask)
80
+ input_rows.append(input_ids)
81
+ attention_rows.append(attention_mask)
82
+ label_rows.append(labels)
83
+ train_tokens.append(sum(1 for value in labels if value != -100))
84
+
85
+ if not input_rows:
86
+ raise ValueError("No valid SFT records found for the diagnostic batch")
87
+
88
+ return {
89
+ "input_ids": torch.tensor(input_rows, dtype=torch.long, device=device),
90
+ "attention_mask": torch.tensor(attention_rows, dtype=torch.long, device=device),
91
+ "labels": torch.tensor(label_rows, dtype=torch.long, device=device),
92
+ "train_tokens": torch.tensor(train_tokens, dtype=torch.long),
93
+ }
94
+
95
+
96
+ @torch.no_grad()
97
+ def score_batch(model, batch: dict[str, torch.Tensor], dtype: torch.dtype) -> float:
98
+ model.eval()
99
+ device_type = "cuda" if batch["input_ids"].is_cuda else "cpu"
100
+ enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)
101
+ with torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled):
102
+ outputs = model(
103
+ input_ids=batch["input_ids"],
104
+ attention_mask=batch["attention_mask"],
105
+ labels=batch["labels"],
106
+ )
107
+ return float(outputs["loss"].detach().cpu())
108
+
109
+
110
+ def grad_l2_norm(parameters) -> float:
111
+ total = 0.0
112
+ for parameter in parameters:
113
+ if parameter.grad is None:
114
+ continue
115
+ grad = parameter.grad.detach()
116
+ total += float(torch.sum(grad.float() * grad.float()).cpu())
117
+ return math.sqrt(total)
118
+
119
+
120
+ def grad_summary(named_parameters, max_items: int = 12) -> dict[str, Any]:
121
+ groups: dict[str, dict[str, Any]] = {}
122
+ worst = []
123
+ nonfinite = []
124
+
125
+ for name, parameter in named_parameters:
126
+ if parameter.grad is None:
127
+ continue
128
+ grad = parameter.grad.detach().float()
129
+ finite = torch.isfinite(grad)
130
+ finite_count = int(finite.sum().cpu())
131
+ numel = grad.numel()
132
+ finite_abs_max = float(grad[finite].abs().max().cpu()) if finite_count else float("inf")
133
+ has_nonfinite = finite_count != numel
134
+ if has_nonfinite:
135
+ nonfinite.append(name)
136
+
137
+ if ".layers." in name:
138
+ parts = name.split(".")
139
+ try:
140
+ idx = parts.index("layers")
141
+ group = "layer_" + parts[idx + 1]
142
+ except (ValueError, IndexError):
143
+ group = "layers"
144
+ else:
145
+ group = name.split(".", 1)[0]
146
+
147
+ entry = groups.setdefault(group, {
148
+ "numel": 0,
149
+ "finite": 0,
150
+ "nonfinite_tensors": 0,
151
+ "max_abs_grad": 0.0,
152
+ })
153
+ entry["numel"] += numel
154
+ entry["finite"] += finite_count
155
+ entry["nonfinite_tensors"] += int(has_nonfinite)
156
+ entry["max_abs_grad"] = max(entry["max_abs_grad"], finite_abs_max)
157
+ worst.append((finite_abs_max, name))
158
+
159
+ worst.sort(reverse=True, key=lambda item: item[0])
160
+ return {
161
+ "groups": groups,
162
+ "worst_tensors": [{"name": name, "max_abs_grad": value} for value, name in worst[:max_items]],
163
+ "nonfinite_tensors": nonfinite[:max_items],
164
+ "nonfinite_tensor_count": len(nonfinite),
165
+ }
166
+
167
+
168
+ def freeze_ssm_core_parameters(model) -> int:
169
+ frozen = 0
170
+ markers = (
171
+ ".ssm_lanes.",
172
+ ".ssm.",
173
+ )
174
+ for name, parameter in model.named_parameters():
175
+ if any(marker in name for marker in markers):
176
+ parameter.requires_grad_(False)
177
+ frozen += parameter.numel()
178
+ return frozen
179
+
180
+
181
+ def main() -> None:
182
+ parser = argparse.ArgumentParser()
183
+ parser.add_argument("--config", required=True)
184
+ parser.add_argument("--checkpoint", required=True)
185
+ parser.add_argument("--output", required=True)
186
+ parser.add_argument("--samples", type=int, default=2)
187
+ parser.add_argument("--steps", type=int, default=80)
188
+ parser.add_argument("--lr", type=float, default=3e-4)
189
+ parser.add_argument("--log-every", type=int, default=10)
190
+ parser.add_argument("--device", default="cuda")
191
+ parser.add_argument("--dtype", choices=["config", "float32", "float16", "bfloat16"], default="config")
192
+ parser.add_argument("--no-clip", action="store_true")
193
+ parser.add_argument("--freeze-ssm-core", action="store_true")
194
+ parser.add_argument("--ssm-branch-rms-norm", action="store_true")
195
+ parser.add_argument("--ssm-branch-clip-value", type=float, default=None)
196
+ parser.add_argument("--block-residual-rms-norm", action="store_true")
197
+ parser.add_argument("--block-residual-rms-target", type=float, default=None)
198
+ parser.add_argument("--seed", type=int, default=123)
199
+ args = parser.parse_args()
200
+
201
+ set_seed(args.seed)
202
+ config = load_config(args.config, TrainingModeEnum.SFT)
203
+ if args.ssm_branch_rms_norm:
204
+ config.model.ssm_branch_rms_norm = True
205
+ if args.ssm_branch_clip_value is not None:
206
+ config.model.ssm_branch_clip_value = args.ssm_branch_clip_value
207
+ if args.block_residual_rms_norm:
208
+ config.model.block_residual_rms_norm = True
209
+ if args.block_residual_rms_target is not None:
210
+ config.model.block_residual_rms_target = args.block_residual_rms_target
211
+ device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
212
+ if args.dtype == "float32":
213
+ dtype = torch.float32
214
+ elif args.dtype == "float16":
215
+ dtype = torch.float16
216
+ elif args.dtype == "bfloat16":
217
+ dtype = torch.bfloat16
218
+ else:
219
+ dtype = torch.bfloat16 if str(config.dtype) == "DataTypeEnum.BFLOAT16" or str(config.dtype) == "bfloat16" else torch.float32
220
+
221
+ tokenizer = load_tokenizer(config.dataset.tokenizer_path)
222
+ records = read_jsonl_records(config.dataset.jsonl_path, args.samples)
223
+ batch = build_batch(config, tokenizer, records, device)
224
+
225
+ model = create_model(config, device)
226
+ checkpoint = CheckpointManager(config.checkpoint_dir).load(args.checkpoint, device=device)
227
+ model.load_state_dict(checkpoint["model_state"], strict=False)
228
+ frozen_params = freeze_ssm_core_parameters(model) if args.freeze_ssm_core else 0
229
+
230
+ initial_loss = score_batch(model, batch, dtype)
231
+
232
+ trainable_params = [parameter for parameter in model.parameters() if parameter.requires_grad]
233
+ optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.0)
234
+ history = []
235
+ device_type = "cuda" if device.type == "cuda" else "cpu"
236
+ autocast_enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)
237
+
238
+ model.train()
239
+ for step in range(1, args.steps + 1):
240
+ optimizer.zero_grad(set_to_none=True)
241
+ with torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
242
+ outputs = model(
243
+ input_ids=batch["input_ids"],
244
+ attention_mask=batch["attention_mask"],
245
+ labels=batch["labels"],
246
+ )
247
+ loss = outputs["loss"]
248
+ loss.backward()
249
+ grad_norm = grad_l2_norm(trainable_params)
250
+ stats = None
251
+ if step == 1 or step % args.log_every == 0 or step == args.steps:
252
+ stats = grad_summary(model.named_parameters())
253
+ if not args.no_clip:
254
+ torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
255
+ optimizer.step()
256
+
257
+ if step == 1 or step % args.log_every == 0 or step == args.steps:
258
+ item = {
259
+ "step": step,
260
+ "loss": float(loss.detach().cpu()),
261
+ "grad_l2_norm": grad_norm,
262
+ }
263
+ if stats is not None:
264
+ item["grad_summary"] = stats
265
+ history.append(item)
266
+
267
+ final_loss = score_batch(model, batch, dtype)
268
+ result = {
269
+ "checkpoint": str(Path(args.checkpoint)),
270
+ "config": str(Path(args.config)),
271
+ "dataset": config.dataset.jsonl_path,
272
+ "samples": len(records),
273
+ "sequence_length": config.model.max_seq_length,
274
+ "train_tokens_per_sample": batch["train_tokens"].tolist(),
275
+ "lr": args.lr,
276
+ "steps": args.steps,
277
+ "clip_grad_norm": not args.no_clip,
278
+ "freeze_ssm_core": args.freeze_ssm_core,
279
+ "ssm_branch_rms_norm": config.model.ssm_branch_rms_norm,
280
+ "ssm_branch_clip_value": config.model.ssm_branch_clip_value,
281
+ "block_residual_rms_norm": config.model.block_residual_rms_norm,
282
+ "block_residual_rms_target": config.model.block_residual_rms_target,
283
+ "frozen_params": frozen_params,
284
+ "trainable_params": sum(parameter.numel() for parameter in trainable_params),
285
+ "initial_loss": initial_loss,
286
+ "final_loss": final_loss,
287
+ "loss_delta": final_loss - initial_loss,
288
+ "history": history,
289
+ "device": str(device),
290
+ "dtype": str(dtype),
291
+ }
292
+
293
+ output = Path(args.output)
294
+ output.parent.mkdir(parents=True, exist_ok=True)
295
+ output.write_text(json.dumps(result, indent=2), encoding="utf-8")
296
+ print(json.dumps(result, indent=2))
297
+
298
+
299
+ if __name__ == "__main__":
300
+ main()
code/TaoTrain/scripts/profile_taonet_components.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Profile TaoNet and TaoNet-SSM component costs on synthetic token batches.
2
+
3
+ The real-token benchmark tells us end-to-end quality and throughput. This
4
+ script is the companion microscope: it times forward components such as the
5
+ SSM core, gates, projections, FFN, embeddings, and output head so hardware work
6
+ targets the largest measured costs.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ from collections import defaultdict
13
+ from contextlib import nullcontext
14
+ from contextlib import redirect_stdout
15
+ import io
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+ import platform
20
+ import sys
21
+ import time
22
+ from typing import Any
23
+
24
+ import torch
25
+
26
+ REPO_ROOT = Path(__file__).resolve().parents[1]
27
+ SRC_ROOT = REPO_ROOT / "src"
28
+ if str(SRC_ROOT) not in sys.path:
29
+ sys.path.insert(0, str(SRC_ROOT))
30
+
31
+ from taoTrain.config import ModelConfig
32
+ from taoTrain.models import get_model
33
+
34
+
35
+ DTYPES = {
36
+ "float32": torch.float32,
37
+ "fp32": torch.float32,
38
+ "float16": torch.float16,
39
+ "fp16": torch.float16,
40
+ "bfloat16": torch.bfloat16,
41
+ "bf16": torch.bfloat16,
42
+ }
43
+
44
+
45
+ def synchronize(device: torch.device) -> None:
46
+ if device.type == "cuda":
47
+ torch.cuda.synchronize(device)
48
+
49
+
50
+ def reset_memory(device: torch.device) -> None:
51
+ if device.type == "cuda":
52
+ torch.cuda.reset_peak_memory_stats(device)
53
+
54
+
55
+ def memory_stats(device: torch.device) -> dict[str, float | None]:
56
+ if device.type != "cuda":
57
+ return {"peak_allocated_mb": None, "peak_reserved_mb": None}
58
+ return {
59
+ "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
60
+ "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
61
+ }
62
+
63
+
64
+ class ComponentTimer:
65
+ def __init__(self, device: torch.device) -> None:
66
+ self.device = device
67
+ self.records: dict[str, list[float]] = defaultdict(list)
68
+ self._starts: dict[int, Any] = {}
69
+ self._handles = []
70
+
71
+ def _record_ms(self, name: str, start: Any) -> None:
72
+ if self.device.type == "cuda":
73
+ end = torch.cuda.Event(enable_timing=True)
74
+ end.record()
75
+ end.synchronize()
76
+ self.records[name].append(float(start.elapsed_time(end)))
77
+ else:
78
+ self.records[name].append((time.perf_counter() - start) * 1000.0)
79
+
80
+ def add(self, module: torch.nn.Module, name: str) -> None:
81
+ def pre_hook(mod, inputs):
82
+ del inputs
83
+ if self.device.type == "cuda":
84
+ start = torch.cuda.Event(enable_timing=True)
85
+ start.record()
86
+ else:
87
+ start = time.perf_counter()
88
+ self._starts[id(mod)] = start
89
+
90
+ def post_hook(mod, inputs, output):
91
+ del inputs, output
92
+ start = self._starts.pop(id(mod), None)
93
+ if start is not None:
94
+ self._record_ms(name, start)
95
+
96
+ self._handles.append(module.register_forward_pre_hook(pre_hook))
97
+ self._handles.append(module.register_forward_hook(post_hook))
98
+
99
+ def close(self) -> None:
100
+ for handle in self._handles:
101
+ handle.remove()
102
+ self._handles.clear()
103
+
104
+ def summary(self) -> list[dict[str, float | str | int]]:
105
+ rows = []
106
+ for name, values in sorted(self.records.items()):
107
+ if not values:
108
+ continue
109
+ rows.append(
110
+ {
111
+ "component": name,
112
+ "calls": len(values),
113
+ "mean_ms": sum(values) / len(values),
114
+ "total_ms": sum(values),
115
+ "min_ms": min(values),
116
+ "max_ms": max(values),
117
+ }
118
+ )
119
+ rows.sort(key=lambda row: float(row["total_ms"]), reverse=True)
120
+ return rows
121
+
122
+
123
+ def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
124
+ d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75)
125
+ d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads
126
+ hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4
127
+ return ModelConfig(
128
+ architecture_type=architecture,
129
+ vocab_size=args.vocab_size,
130
+ hidden_dim=args.hidden_dim,
131
+ num_layers=args.num_layers,
132
+ num_heads=args.num_heads,
133
+ max_seq_length=args.seq_len,
134
+ d_latent_kv=d_latent_kv,
135
+ d_rope=d_rope,
136
+ hidden_dim_ff=hidden_dim_ff,
137
+ dropout=args.dropout,
138
+ gqa_groups=args.gqa_groups,
139
+ rope_scale=args.rope_scale,
140
+ yarn_alpha=args.yarn_alpha,
141
+ init_std=args.init_std,
142
+ ssm_core=args.ssm_core,
143
+ ssm_hidden_dim=args.ssm_hidden_dim,
144
+ ssm_mixer_dim=args.ssm_mixer_dim,
145
+ ssm_rank=args.ssm_rank,
146
+ ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
147
+ ssm_kernel_mode=args.ssm_kernel_mode,
148
+ ssm_kernel_threshold=args.ssm_kernel_threshold,
149
+ ssm_dt_min=args.ssm_dt_min,
150
+ ssm_dt_max=args.ssm_dt_max,
151
+ ssm_dt_init=args.ssm_dt_init,
152
+ ssm_use_padding_mask=False,
153
+ ssm_activation=args.ssm_activation,
154
+ ssm_gate=args.ssm_gate,
155
+ ssm_input_gate=args.ssm_input_gate,
156
+ ssm_layer_scale_init=args.ssm_layer_scale_init,
157
+ ssm_local_shift=args.ssm_local_shift,
158
+ ssm_local_shift_init=args.ssm_local_shift_init,
159
+ ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
160
+ )
161
+
162
+
163
+ def add_component_hooks(model: torch.nn.Module, architecture: str, timer: ComponentTimer) -> None:
164
+ timer.add(model.token_embedding, "embedding")
165
+ timer.add(model.final_norm, "final_norm")
166
+ timer.add(model.output_head, "output_head")
167
+ for layer_index, block in enumerate(model.blocks):
168
+ prefix = f"block{layer_index}"
169
+ if architecture == "taonet_ssm":
170
+ mixer = block.mixer
171
+ timer.add(mixer.norm, f"{prefix}.mixer.norm")
172
+ if mixer.input_gate is not None:
173
+ timer.add(mixer.input_gate, f"{prefix}.mixer.input_gate")
174
+ timer.add(mixer.input_proj, f"{prefix}.mixer.input_proj")
175
+ timer.add(mixer.ssm, f"{prefix}.mixer.ssm_core")
176
+ timer.add(mixer.activation, f"{prefix}.mixer.activation")
177
+ timer.add(mixer.out_proj, f"{prefix}.mixer.out_proj")
178
+ if mixer.output_gate is not None:
179
+ timer.add(mixer.output_gate, f"{prefix}.mixer.output_gate")
180
+ timer.add(mixer.proj_dropout, f"{prefix}.mixer.dropout")
181
+ else:
182
+ mla = block.mla
183
+ timer.add(mla.norm, f"{prefix}.attention.norm")
184
+ timer.add(mla.q_proj, f"{prefix}.attention.q_proj")
185
+ timer.add(mla.k_proj, f"{prefix}.attention.k_proj")
186
+ timer.add(mla.v_proj, f"{prefix}.attention.v_proj")
187
+ timer.add(mla.out_proj, f"{prefix}.attention.out_proj")
188
+ timer.add(mla.attn_dropout, f"{prefix}.attention.attn_dropout")
189
+ timer.add(mla.proj_dropout, f"{prefix}.attention.proj_dropout")
190
+ timer.add(block.ff_norm, f"{prefix}.ff.norm")
191
+ timer.add(block.ff_gate, f"{prefix}.ff.gate")
192
+ timer.add(block.ff_value, f"{prefix}.ff.value")
193
+ timer.add(block.ff_out, f"{prefix}.ff.out")
194
+
195
+
196
+ def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]:
197
+ for _ in range(warmup):
198
+ fn()
199
+ synchronize(device)
200
+
201
+ latencies = []
202
+ for _ in range(repeats):
203
+ reset_memory(device)
204
+ synchronize(device)
205
+ start = time.perf_counter()
206
+ fn()
207
+ synchronize(device)
208
+ latencies.append(time.perf_counter() - start)
209
+ mean_s = sum(latencies) / len(latencies)
210
+ return {
211
+ "mean_ms": mean_s * 1000.0,
212
+ "min_ms": min(latencies) * 1000.0,
213
+ "max_ms": max(latencies) * 1000.0,
214
+ }
215
+
216
+
217
+ def profile_architecture(
218
+ args: argparse.Namespace,
219
+ *,
220
+ architecture: str,
221
+ device: torch.device,
222
+ dtype: torch.dtype,
223
+ ) -> dict[str, Any]:
224
+ torch.manual_seed(args.seed)
225
+ if device.type == "cuda":
226
+ torch.cuda.manual_seed_all(args.seed)
227
+
228
+ config = build_config(args, architecture)
229
+ with redirect_stdout(io.StringIO()):
230
+ model = get_model(config, device=device)
231
+ model.train()
232
+
233
+ input_ids = torch.randint(
234
+ low=0,
235
+ high=args.vocab_size,
236
+ size=(args.batch_size, args.seq_len),
237
+ device=device,
238
+ )
239
+ labels = torch.randint(
240
+ low=0,
241
+ high=args.vocab_size,
242
+ size=(args.batch_size, args.seq_len),
243
+ device=device,
244
+ )
245
+ attention_mask = torch.ones_like(input_ids)
246
+
247
+ autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
248
+
249
+ def autocast_context():
250
+ if not autocast_enabled:
251
+ return nullcontext()
252
+ return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)
253
+
254
+ def forward_only() -> torch.Tensor:
255
+ with torch.no_grad():
256
+ with autocast_context():
257
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
258
+ return outputs["loss"]
259
+
260
+ def forward_backward() -> torch.Tensor:
261
+ model.zero_grad(set_to_none=True)
262
+ with autocast_context():
263
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
264
+ loss = outputs["loss"]
265
+ loss.backward()
266
+ return loss
267
+
268
+ no_timer_forward = time_repeats(
269
+ forward_only,
270
+ device=device,
271
+ warmup=args.warmup,
272
+ repeats=args.repeats,
273
+ )
274
+ no_timer_backward = time_repeats(
275
+ forward_backward,
276
+ device=device,
277
+ warmup=args.warmup,
278
+ repeats=args.repeats,
279
+ )
280
+
281
+ timer = ComponentTimer(device)
282
+ add_component_hooks(model, architecture, timer)
283
+ try:
284
+ for _ in range(args.component_warmup):
285
+ forward_only()
286
+ synchronize(device)
287
+ for _ in range(args.component_repeats):
288
+ forward_only()
289
+ synchronize(device)
290
+ finally:
291
+ timer.close()
292
+
293
+ tokens = args.batch_size * args.seq_len
294
+ component_rows = timer.summary()
295
+ return {
296
+ "architecture": architecture,
297
+ "total_params": sum(param.numel() for param in model.parameters()),
298
+ "trainable_params": sum(param.numel() for param in model.parameters() if param.requires_grad),
299
+ "forward": {
300
+ **no_timer_forward,
301
+ "tokens_per_s": tokens / max(no_timer_forward["mean_ms"] / 1000.0, 1e-12),
302
+ },
303
+ "forward_backward": {
304
+ **no_timer_backward,
305
+ "tokens_per_s": tokens / max(no_timer_backward["mean_ms"] / 1000.0, 1e-12),
306
+ **memory_stats(device),
307
+ },
308
+ "components_forward": component_rows,
309
+ }
310
+
311
+
312
+ def main() -> int:
313
+ parser = argparse.ArgumentParser(description=__doc__)
314
+ parser.add_argument("--architectures", default="taonet,taonet_ssm")
315
+ parser.add_argument("--vocab-size", type=int, default=8192)
316
+ parser.add_argument("--batch-size", type=int, default=32)
317
+ parser.add_argument("--seq-len", type=int, default=512)
318
+ parser.add_argument("--hidden-dim", type=int, default=256)
319
+ parser.add_argument("--num-layers", type=int, default=4)
320
+ parser.add_argument("--num-heads", type=int, default=4)
321
+ parser.add_argument("--d-latent-kv", type=int, default=None)
322
+ parser.add_argument("--d-rope", type=int, default=None)
323
+ parser.add_argument("--hidden-dim-ff", type=int, default=None)
324
+ parser.add_argument("--dropout", type=float, default=0.0)
325
+ parser.add_argument("--gqa-groups", type=int, default=1)
326
+ parser.add_argument("--rope-scale", type=float, default=40.0)
327
+ parser.add_argument("--yarn-alpha", type=float, default=1.0)
328
+ parser.add_argument("--init-std", type=float, default=0.02)
329
+ parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
330
+ parser.add_argument("--ssm-hidden-dim", type=int, default=16)
331
+ parser.add_argument("--ssm-mixer-dim", type=int, default=128)
332
+ parser.add_argument("--ssm-rank", type=int, default=1)
333
+ parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
334
+ parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
335
+ parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
336
+ parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
337
+ parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
338
+ parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
339
+ parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
340
+ parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
341
+ parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
342
+ parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
343
+ parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True)
344
+ parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
345
+ parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=True)
346
+ parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
347
+ parser.add_argument("--device", default="auto")
348
+ parser.add_argument("--warmup", type=int, default=2)
349
+ parser.add_argument("--repeats", type=int, default=5)
350
+ parser.add_argument("--component-warmup", type=int, default=1)
351
+ parser.add_argument("--component-repeats", type=int, default=3)
352
+ parser.add_argument("--seed", type=int, default=42)
353
+ parser.add_argument("--output", type=Path, default=None)
354
+ args = parser.parse_args()
355
+
356
+ if args.device == "auto":
357
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
358
+ else:
359
+ device = torch.device(args.device)
360
+ dtype = DTYPES[args.dtype]
361
+ if device.type == "cuda":
362
+ torch.backends.cuda.matmul.allow_tf32 = True
363
+ torch.backends.cudnn.allow_tf32 = True
364
+
365
+ results = [
366
+ profile_architecture(args, architecture=architecture.strip(), device=device, dtype=dtype)
367
+ for architecture in args.architectures.split(",")
368
+ if architecture.strip()
369
+ ]
370
+ report = {
371
+ "metadata": {
372
+ "python": platform.python_version(),
373
+ "platform": platform.platform(),
374
+ "torch": torch.__version__,
375
+ "cuda_available": torch.cuda.is_available(),
376
+ "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
377
+ "device": str(device),
378
+ "dtype": str(dtype).replace("torch.", ""),
379
+ "args": vars(args) | {"output": str(args.output) if args.output else None},
380
+ },
381
+ "results": results,
382
+ }
383
+
384
+ text = json.dumps(report, indent=2, sort_keys=True, default=str)
385
+ print(text)
386
+ if args.output is not None:
387
+ args.output.parent.mkdir(parents=True, exist_ok=True)
388
+ args.output.write_text(text, encoding="utf-8")
389
+ return 0
390
+
391
+
392
+ if __name__ == "__main__":
393
+ raise SystemExit(main())
code/TaoTrain/scripts/remote/job_status.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-}"
5
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
6
+
7
+ if [[ -z "$RUN_ID" ]]; then
8
+ echo "RUN_ID is required" >&2
9
+ exit 2
10
+ fi
11
+
12
+ safe_run_id="$(printf '%s' "$RUN_ID" | tr -c 'A-Za-z0-9_.-' '_')"
13
+ job_dir="${JOB_ROOT%/}/${safe_run_id}"
14
+
15
+ if [[ ! -d "$job_dir" ]]; then
16
+ echo "Job directory not found: $job_dir" >&2
17
+ exit 1
18
+ fi
19
+
20
+ echo "== status.json =="
21
+ cat "$job_dir/status.json" 2>/dev/null || true
22
+ echo
23
+ echo "== markers =="
24
+ ls -1 "$job_dir"/DONE "$job_dir"/FAILED 2>/dev/null || true
25
+ echo
26
+ echo "== tmux =="
27
+ tmux ls 2>/dev/null | grep -F "taotern_${safe_run_id}" || true
28
+ echo
29
+ echo "== recent log =="
30
+ tail -n "${TAIL_LINES:-80}" "$job_dir/train.log" 2>/dev/null || true
31
+ echo
32
+ echo "== outputs =="
33
+ find "$job_dir/outputs" -maxdepth 2 -type f 2>/dev/null | sort | tail -n 40 || true
34
+ echo
35
+ echo "== checkpoints =="
36
+ find "$job_dir/checkpoints" -maxdepth 1 -type f 2>/dev/null | sort | tail -n 20 || true
code/TaoTrain/scripts/remote/run_200m_base_suite.sh ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
5
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
6
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
8
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
9
+ OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/200m-base-suite}"
10
+ CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
11
+
12
+ # Stage-1 defaults are intentionally modest. Increase these through environment
13
+ # variables after the 200M shapes are stable on the RTX5090.
14
+ MAX_TOKENS="${MAX_TOKENS:-50000000}"
15
+ MAX_RECORDS="${MAX_RECORDS:-100000}"
16
+ TRAIN_STEPS="${TRAIN_STEPS:-200}"
17
+ EVAL_BATCHES="${EVAL_BATCHES:-16}"
18
+ BATCH_SIZES="${BATCH_SIZES:-4,8}"
19
+ SEQ_LEN="${SEQ_LEN:-512}"
20
+ LEARNING_RATE="${LEARNING_RATE:-0.0006}"
21
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
22
+ DRY_RUN="${DRY_RUN:-0}"
23
+
24
+ export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
25
+ mkdir -p "$OUTPUT_BASE" "$CHECKPOINT_BASE"
26
+
27
+ run_variant() {
28
+ local variant="$1"
29
+ shift
30
+ local output_dir="$OUTPUT_BASE/$variant"
31
+ local checkpoint_dir="$CHECKPOINT_BASE/$variant"
32
+ mkdir -p "$output_dir" "$checkpoint_dir"
33
+
34
+ local cmd="$PYTHON_BIN scripts/benchmark_taonet_real_tokens.py \
35
+ --data-path $DATA_PATH \
36
+ --text-field text \
37
+ --tokenizer-type sentencepiece \
38
+ --tokenizer-path $TOKENIZER_PATH \
39
+ --max-records $MAX_RECORDS \
40
+ --max-tokens $MAX_TOKENS \
41
+ --eval-fraction 0.1 \
42
+ --batch-sizes $BATCH_SIZES \
43
+ --seq-len $SEQ_LEN \
44
+ --dtype bf16 \
45
+ --device cuda \
46
+ --warmup 1 \
47
+ --repeats 2 \
48
+ --backward \
49
+ --train-steps $TRAIN_STEPS \
50
+ --learning-rate $LEARNING_RATE \
51
+ --weight-decay $WEIGHT_DECAY \
52
+ --eval-batches $EVAL_BATCHES \
53
+ --output-dir $output_dir \
54
+ --resume-completed \
55
+ --incremental-output \
56
+ --save-case-checkpoints \
57
+ --checkpoint-dir $checkpoint_dir \
58
+ $*"
59
+
60
+ printf '\n=== 200M variant: %s ===\n' "$variant"
61
+ printf '%s\n' "$cmd"
62
+ if [ "$DRY_RUN" = "1" ]; then
63
+ return 0
64
+ fi
65
+ eval "$cmd"
66
+ }
67
+
68
+ run_variant attention_196m \
69
+ --architectures taonet \
70
+ --hidden-dim 960 \
71
+ --num-layers 16 \
72
+ --num-heads 8 \
73
+ --d-latent-kv 720 \
74
+ --d-rope 120 \
75
+ --hidden-dim-ff 2880
76
+
77
+ run_variant pure_ssm_196m_hadamard \
78
+ --architectures taonet_ssm \
79
+ --hidden-dim 1024 \
80
+ --num-layers 18 \
81
+ --num-heads 8 \
82
+ --d-latent-kv 768 \
83
+ --d-rope 128 \
84
+ --hidden-dim-ff 3072 \
85
+ --ssm-core dplr \
86
+ --ssm-hidden-dims 16 \
87
+ --ssm-mixer-dims 256 \
88
+ --ssm-num-lanes-list 2 \
89
+ --ssm-lane-combine channel \
90
+ --ssm-lane-modes split \
91
+ --ssm-split-mixes hadamard \
92
+ --ssm-rank 1 \
93
+ --ssm-kernel-mode conv \
94
+ --no-ssm-finite-tail-correction \
95
+ --ssm-gate-types channel \
96
+ --ssm-local-shift \
97
+ --ssm-local-shift-per-channel \
98
+ --ssm-local-shift-init 0.1
99
+
100
+ run_variant pure_ssm_196m_nomix \
101
+ --architectures taonet_ssm \
102
+ --hidden-dim 1024 \
103
+ --num-layers 18 \
104
+ --num-heads 8 \
105
+ --d-latent-kv 768 \
106
+ --d-rope 128 \
107
+ --hidden-dim-ff 3072 \
108
+ --ssm-core dplr \
109
+ --ssm-hidden-dims 16 \
110
+ --ssm-mixer-dims 256 \
111
+ --ssm-num-lanes-list 2 \
112
+ --ssm-lane-combine channel \
113
+ --ssm-lane-modes split \
114
+ --ssm-split-mixes none \
115
+ --ssm-rank 1 \
116
+ --ssm-kernel-mode conv \
117
+ --no-ssm-finite-tail-correction \
118
+ --ssm-gate-types channel \
119
+ --ssm-local-shift \
120
+ --ssm-local-shift-per-channel \
121
+ --ssm-local-shift-init 0.1
122
+
123
+ run_variant hybrid_ssm_first_199m \
124
+ --architectures taonet_hybrid \
125
+ --hidden-dim 1024 \
126
+ --num-layers 16 \
127
+ --num-heads 8 \
128
+ --d-latent-kv 768 \
129
+ --d-rope 128 \
130
+ --hidden-dim-ff 3072 \
131
+ --ssm-core dplr \
132
+ --ssm-hidden-dims 32 \
133
+ --ssm-mixer-dims 256 \
134
+ --ssm-num-lanes-list 2 \
135
+ --ssm-lane-combine channel \
136
+ --ssm-lane-modes split \
137
+ --ssm-split-mixes hadamard \
138
+ --ssm-rank 1 \
139
+ --ssm-kernel-mode conv \
140
+ --no-ssm-finite-tail-correction \
141
+ --ssm-gate-types channel \
142
+ --hybrid-patterns ssm_first \
143
+ --ssm-local-shift \
144
+ --ssm-local-shift-per-channel \
145
+ --ssm-local-shift-init 0.1
146
+
147
+ if [ "$DRY_RUN" != "1" ]; then
148
+ "$PYTHON_BIN" scripts/summarize_taonet_benchmark_suite.py --suite-dir "$OUTPUT_BASE"
149
+ fi
code/TaoTrain/scripts/remote/run_200m_branch_only_chat.sh ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
5
+ SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
6
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
7
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
8
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
9
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
10
+ OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/200m-branch-only-chat}"
11
+ CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
12
+
13
+ SEQ_LEN="${SEQ_LEN:-512}"
14
+ BATCH_SIZE="${BATCH_SIZE:-8}"
15
+ PRETRAIN_TOKENS="${PRETRAIN_TOKENS:-4000000000}"
16
+ SFT_STEPS="${SFT_STEPS:-50000}"
17
+ PRETRAIN_LR="${PRETRAIN_LR:-0.0008}"
18
+ SFT_LR="${SFT_LR:-0.00005}"
19
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
20
+ LOG_EVERY="${LOG_EVERY:-100}"
21
+ SAVE_EVERY="${SAVE_EVERY:-100000}"
22
+ SFT_SAVE_EVERY="${SFT_SAVE_EVERY:-10000}"
23
+ TOKENIZER_THREADS="${TOKENIZER_THREADS:-8}"
24
+ SAMPLES_PER_CHUNK="${SAMPLES_PER_CHUNK:-2000}"
25
+ BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
26
+
27
+ ceil_div() {
28
+ local numerator="$1"
29
+ local denominator="$2"
30
+ echo $(( (numerator + denominator - 1) / denominator ))
31
+ }
32
+
33
+ PRETRAIN_STEPS="${PRETRAIN_STEPS:-$(ceil_div "$PRETRAIN_TOKENS" $((BATCH_SIZE * SEQ_LEN)))}"
34
+
35
+ export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
36
+ mkdir -p "$OUTPUT_BASE/configs" "$OUTPUT_BASE/diagnostics" "$CHECKPOINT_BASE/pretrain" "$CHECKPOINT_BASE/sft"
37
+
38
+ cap_yaml=""
39
+ if [[ -n "$BLOCK_RESIDUAL_RMS_CAP" ]]; then
40
+ cap_yaml=" block_residual_rms_cap: $BLOCK_RESIDUAL_RMS_CAP"
41
+ fi
42
+
43
+ cat > "$OUTPUT_BASE/run_plan.json" <<JSON
44
+ {
45
+ "purpose": "200m_branch_only_pure_ssm_4b_pretrain_plus_sft_chatbot_attempt",
46
+ "candidate": "pure_ssm_196m_branch_rms_only",
47
+ "pretrain_token_positions": $PRETRAIN_TOKENS,
48
+ "pretrain_steps": $PRETRAIN_STEPS,
49
+ "batch_size": $BATCH_SIZE,
50
+ "seq_len": $SEQ_LEN,
51
+ "sft_steps": $SFT_STEPS,
52
+ "pretrain_lr": $PRETRAIN_LR,
53
+ "sft_lr": $SFT_LR,
54
+ "save_every_steps": $SAVE_EVERY,
55
+ "block_residual_rms_cap": "${BLOCK_RESIDUAL_RMS_CAP:-null}",
56
+ "notes": [
57
+ "Uses streaming JSONL training through TaoTrain CLI.",
58
+ "Counts token positions as batch_size * seq_len * max_steps.",
59
+ "Keeps SSM branch RMS normalization enabled and block residual RMS normalization disabled.",
60
+ "Runs corrected response-only SFT after pretraining."
61
+ ]
62
+ }
63
+ JSON
64
+
65
+ PRETRAIN_CONFIG="$OUTPUT_BASE/configs/pretrain.yaml"
66
+ cat > "$PRETRAIN_CONFIG" <<YAML
67
+ model:
68
+ architecture_type: taonet_ssm
69
+ vocab_size: 8192
70
+ hidden_dim: 1024
71
+ num_layers: 18
72
+ num_heads: 8
73
+ max_seq_length: $SEQ_LEN
74
+ d_latent_kv: 768
75
+ d_rope: 128
76
+ hidden_dim_ff: 3072
77
+ dropout: 0.0
78
+ gqa_groups: 1
79
+ use_factorized_embedding: false
80
+ d_embed_rank: 96
81
+ init_std: 0.02
82
+ ssm_core: dplr
83
+ ssm_hidden_dim: 32
84
+ ssm_mixer_dim: 256
85
+ ssm_num_lanes: 2
86
+ ssm_lane_combine: channel
87
+ ssm_lane_mode: split
88
+ ssm_split_mix: none
89
+ ssm_rank: 1
90
+ ssm_max_low_rank_scale: 0.1
91
+ ssm_finite_tail_correction: false
92
+ ssm_discretization: bilinear
93
+ ssm_kernel_mode: conv
94
+ ssm_kernel_threshold: 64
95
+ ssm_dt_min: 1e-3
96
+ ssm_dt_max: 1e-1
97
+ ssm_dt_init: 1e-2
98
+ ssm_use_d: true
99
+ ssm_activation: gelu
100
+ ssm_gate: true
101
+ ssm_input_gate: true
102
+ ssm_gate_type: channel
103
+ ssm_use_padding_mask: false
104
+ ssm_layer_scale_init: 0.1
105
+ ssm_branch_rms_norm: true
106
+ block_residual_rms_norm: false
107
+ $cap_yaml
108
+ ssm_local_shift: true
109
+ ssm_local_shift_init: 0.1
110
+ ssm_local_shift_per_channel: true
111
+
112
+ dataset:
113
+ local: true
114
+ jsonl_path: $DATA_PATH
115
+ text_field: text
116
+ tokenizer_type: sentencepiece
117
+ tokenizer_path: $TOKENIZER_PATH
118
+ samples_per_chunk: $SAMPLES_PER_CHUNK
119
+ tokenizer_threads: $TOKENIZER_THREADS
120
+
121
+ sequence_length: $SEQ_LEN
122
+ batch_size: $BATCH_SIZE
123
+ num_epochs: 100000
124
+ max_steps: $PRETRAIN_STEPS
125
+ gradient_accumulation_steps: 1
126
+ max_grad_norm: 1.0
127
+
128
+ optimizer:
129
+ optimizer_type: adamw
130
+ learning_rate: $PRETRAIN_LR
131
+ weight_decay: $WEIGHT_DECAY
132
+ betas: [0.9, 0.999]
133
+ eps: 1e-8
134
+
135
+ scheduler:
136
+ scheduler_type: constant
137
+ warmup_steps: 1000
138
+ warmup_ratio: 0.0
139
+
140
+ dtype: bfloat16
141
+ device: cuda
142
+ checkpoint_dir: $CHECKPOINT_BASE/pretrain
143
+ save_every_steps: $SAVE_EVERY
144
+ save_best_model: false
145
+ keep_last_n_checkpoints: 3
146
+ eval_every_steps: $SAVE_EVERY
147
+ eval_samples: 32
148
+ log_every_steps: $LOG_EVERY
149
+ aim_repo: $OUTPUT_BASE/.aim-pretrain
150
+ seed: 43
151
+ num_workers: 0
152
+ pin_memory: true
153
+ YAML
154
+
155
+ printf '\n============================================================\n'
156
+ printf '200M branch-only pure SSM pretrain\n'
157
+ printf 'token_positions=%s batch=%s seq_len=%s steps=%s save_every=%s\n' \
158
+ "$PRETRAIN_TOKENS" "$BATCH_SIZE" "$SEQ_LEN" "$PRETRAIN_STEPS" "$SAVE_EVERY"
159
+ printf '============================================================\n'
160
+
161
+ "$PYTHON_BIN" -m taoTrain.cli pretrain --config "$PRETRAIN_CONFIG"
162
+
163
+ PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/final_model.pt"
164
+ if [[ ! -f "$PRETRAIN_CKPT" ]]; then
165
+ echo "Expected pretrain checkpoint missing: $PRETRAIN_CKPT" >&2
166
+ exit 2
167
+ fi
168
+
169
+ "$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
170
+ --checkpoint "$PRETRAIN_CKPT" \
171
+ --tokenizer-path "$TOKENIZER_PATH" \
172
+ --data-path "$DATA_PATH" \
173
+ --text-field text \
174
+ --output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_final.json" \
175
+ --batch-size 2 \
176
+ --seq-len "$SEQ_LEN" \
177
+ --device cuda \
178
+ --dtype bfloat16
179
+
180
+ "$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
181
+ --checkpoint "$PRETRAIN_CKPT" \
182
+ --tokenizer-path "$TOKENIZER_PATH" \
183
+ --output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_final.json" \
184
+ --max-new-tokens 160 \
185
+ --temperature 0.8 \
186
+ --top-p 0.9 \
187
+ --prompt "The purpose of artificial intelligence is" \
188
+ --prompt "In a small village," \
189
+ --prompt "<user>Hello, who are you?<assistant>"
190
+
191
+ SFT_CONFIG="$OUTPUT_BASE/configs/sft.yaml"
192
+ cat > "$SFT_CONFIG" <<YAML
193
+ model:
194
+ architecture_type: taonet_ssm
195
+ vocab_size: 8192
196
+ hidden_dim: 1024
197
+ num_layers: 18
198
+ num_heads: 8
199
+ max_seq_length: $SEQ_LEN
200
+ d_latent_kv: 768
201
+ d_rope: 128
202
+ hidden_dim_ff: 3072
203
+ dropout: 0.0
204
+ gqa_groups: 1
205
+ use_factorized_embedding: false
206
+ d_embed_rank: 96
207
+ init_std: 0.02
208
+ ssm_core: dplr
209
+ ssm_hidden_dim: 32
210
+ ssm_mixer_dim: 256
211
+ ssm_num_lanes: 2
212
+ ssm_lane_combine: channel
213
+ ssm_lane_mode: split
214
+ ssm_split_mix: none
215
+ ssm_rank: 1
216
+ ssm_max_low_rank_scale: 0.1
217
+ ssm_finite_tail_correction: false
218
+ ssm_discretization: bilinear
219
+ ssm_kernel_mode: conv
220
+ ssm_kernel_threshold: 64
221
+ ssm_dt_min: 1e-3
222
+ ssm_dt_max: 1e-1
223
+ ssm_dt_init: 1e-2
224
+ ssm_use_d: true
225
+ ssm_activation: gelu
226
+ ssm_gate: true
227
+ ssm_input_gate: true
228
+ ssm_gate_type: channel
229
+ ssm_use_padding_mask: false
230
+ ssm_layer_scale_init: 0.1
231
+ ssm_branch_rms_norm: true
232
+ block_residual_rms_norm: false
233
+ $cap_yaml
234
+ ssm_local_shift: true
235
+ ssm_local_shift_init: 0.1
236
+ ssm_local_shift_per_channel: true
237
+
238
+ dataset:
239
+ split: train
240
+ instruction_column: input
241
+ response_column: output
242
+ local: true
243
+ jsonl_path: $SFT_DATA_PATH
244
+ samples_per_chunk: $SAMPLES_PER_CHUNK
245
+ tokenizer_type: sentencepiece
246
+ tokenizer_path: $TOKENIZER_PATH
247
+ tokenizer_threads: $TOKENIZER_THREADS
248
+
249
+ checkpoint_path: $PRETRAIN_CKPT
250
+ user_token: "<user>"
251
+ assistant_token: "<assistant>"
252
+ response_loss_only: true
253
+
254
+ batch_size: $BATCH_SIZE
255
+ num_epochs: 100000
256
+ max_steps: $SFT_STEPS
257
+ gradient_accumulation_steps: 1
258
+ max_grad_norm: 1.0
259
+
260
+ optimizer:
261
+ optimizer_type: adamw
262
+ learning_rate: $SFT_LR
263
+ weight_decay: 0.0
264
+ betas: [0.9, 0.999]
265
+ eps: 1e-8
266
+
267
+ scheduler:
268
+ scheduler_type: constant
269
+ warmup_steps: 100
270
+ warmup_ratio: 0.0
271
+
272
+ dtype: bfloat16
273
+ device: cuda
274
+ checkpoint_dir: $CHECKPOINT_BASE/sft
275
+ save_every_steps: $SFT_SAVE_EVERY
276
+ save_best_model: false
277
+ keep_last_n_checkpoints: 3
278
+ eval_every_steps: $SFT_SAVE_EVERY
279
+ eval_samples: 32
280
+ log_every_steps: 20
281
+ aim_repo: $OUTPUT_BASE/.aim-sft
282
+ seed: 44
283
+ num_workers: 0
284
+ pin_memory: true
285
+ YAML
286
+
287
+ printf '\n============================================================\n'
288
+ printf '200M branch-only pure SSM response-only SFT\n'
289
+ printf 'steps=%s batch=%s seq_len=%s save_every=%s\n' \
290
+ "$SFT_STEPS" "$BATCH_SIZE" "$SEQ_LEN" "$SFT_SAVE_EVERY"
291
+ printf '============================================================\n'
292
+
293
+ "$PYTHON_BIN" -m taoTrain.cli sft --config "$SFT_CONFIG"
294
+
295
+ SFT_CKPT="$CHECKPOINT_BASE/sft/final_model.pt"
296
+ if [[ ! -f "$SFT_CKPT" ]]; then
297
+ echo "Expected SFT checkpoint missing: $SFT_CKPT" >&2
298
+ exit 2
299
+ fi
300
+
301
+ "$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
302
+ --checkpoint "$SFT_CKPT" \
303
+ --tokenizer-path "$TOKENIZER_PATH" \
304
+ --output "$OUTPUT_BASE/diagnostics/generation_samples_sft_final.json" \
305
+ --max-new-tokens 220 \
306
+ --temperature 0.7 \
307
+ --top-p 0.9 \
308
+ --prompt "<user>Hello, who are you?<assistant>" \
309
+ --prompt "<user>Explain what artificial intelligence is in simple words.<assistant>" \
310
+ --prompt "<user>Give me three practical study tips.<assistant>"
311
+
312
+ cat > "$OUTPUT_BASE/model_card.json" <<JSON
313
+ {
314
+ "state": "completed",
315
+ "architecture": "taonet_ssm",
316
+ "candidate": "pure_ssm_196m_branch_rms_only",
317
+ "pretrain_checkpoint": "$PRETRAIN_CKPT",
318
+ "sft_checkpoint": "$SFT_CKPT",
319
+ "tokenizer_path": "$TOKENIZER_PATH",
320
+ "pretrain_config": "$PRETRAIN_CONFIG",
321
+ "sft_config": "$SFT_CONFIG",
322
+ "activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_final.json",
323
+ "pretrain_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_final.json",
324
+ "sft_samples": "$OUTPUT_BASE/diagnostics/generation_samples_sft_final.json"
325
+ }
326
+ JSON
327
+
328
+ echo "200M branch-only pure SSM pretrain+SFT run completed."
329
+ cat "$OUTPUT_BASE/model_card.json"
code/TaoTrain/scripts/remote/run_200m_until_selection.sh ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ ROOT_OUTPUT_DIR="${REPOBRIDGE_OUTPUT_DIR:-$(pwd)/results/200m-until-selection}"
5
+ ROOT_CHECKPOINT_DIR="${TAOTERN_CHECKPOINT_DIR:-$ROOT_OUTPUT_DIR/checkpoints}"
6
+ SEQ_LEN="${SEQ_LEN:-512}"
7
+ BATCH_SIZE="${BATCH_SIZE:-8}"
8
+ PILOT_TOKENS="${PILOT_TOKENS:-300000000}"
9
+ SERIOUS_TOKENS="${SERIOUS_TOKENS:-1000000000}"
10
+ PILOT_EVAL_BATCHES="${PILOT_EVAL_BATCHES:-64}"
11
+ SERIOUS_EVAL_BATCHES="${SERIOUS_EVAL_BATCHES:-128}"
12
+ LEARNING_RATE="${LEARNING_RATE:-0.0003}"
13
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
14
+
15
+ ceil_div() {
16
+ local numerator="$1"
17
+ local denominator="$2"
18
+ echo $(( (numerator + denominator - 1) / denominator ))
19
+ }
20
+
21
+ run_phase() {
22
+ local phase="$1"
23
+ local target_tokens="$2"
24
+ local eval_batches="$3"
25
+ local tokens_per_step=$((BATCH_SIZE * SEQ_LEN))
26
+ local train_steps
27
+ train_steps="$(ceil_div "$target_tokens" "$tokens_per_step")"
28
+
29
+ printf '\n============================================================\n'
30
+ printf '200M until-selection phase: %s\n' "$phase"
31
+ printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s\n' \
32
+ "$target_tokens" "$BATCH_SIZE" "$SEQ_LEN" "$train_steps" "$eval_batches"
33
+ printf '============================================================\n'
34
+
35
+ REPOBRIDGE_OUTPUT_DIR="$ROOT_OUTPUT_DIR/$phase" \
36
+ TAOTERN_CHECKPOINT_DIR="$ROOT_CHECKPOINT_DIR/$phase" \
37
+ BATCH_SIZES="$BATCH_SIZE" \
38
+ SEQ_LEN="$SEQ_LEN" \
39
+ TRAIN_STEPS="$train_steps" \
40
+ EVAL_BATCHES="$eval_batches" \
41
+ LEARNING_RATE="$LEARNING_RATE" \
42
+ WEIGHT_DECAY="$WEIGHT_DECAY" \
43
+ bash scripts/remote/run_200m_base_suite.sh
44
+ }
45
+
46
+ mkdir -p "$ROOT_OUTPUT_DIR" "$ROOT_CHECKPOINT_DIR"
47
+
48
+ cat > "$ROOT_OUTPUT_DIR/run_plan.json" <<JSON
49
+ {
50
+ "stopping_point": "selection_after_1b_all_four_variants",
51
+ "batch_size": $BATCH_SIZE,
52
+ "seq_len": $SEQ_LEN,
53
+ "learning_rate": $LEARNING_RATE,
54
+ "weight_decay": $WEIGHT_DECAY,
55
+ "phases": [
56
+ {
57
+ "name": "pilot_300m",
58
+ "target_tokens_per_variant": $PILOT_TOKENS,
59
+ "train_steps": $(ceil_div "$PILOT_TOKENS" $((BATCH_SIZE * SEQ_LEN))),
60
+ "eval_batches": $PILOT_EVAL_BATCHES
61
+ },
62
+ {
63
+ "name": "serious_1b",
64
+ "target_tokens_per_variant": $SERIOUS_TOKENS,
65
+ "train_steps": $(ceil_div "$SERIOUS_TOKENS" $((BATCH_SIZE * SEQ_LEN))),
66
+ "eval_batches": $SERIOUS_EVAL_BATCHES
67
+ }
68
+ ],
69
+ "variants": [
70
+ "attention_196m",
71
+ "pure_ssm_196m_hadamard",
72
+ "pure_ssm_196m_nomix",
73
+ "hybrid_ssm_first_199m"
74
+ ]
75
+ }
76
+ JSON
77
+
78
+ run_phase "pilot_300m" "$PILOT_TOKENS" "$PILOT_EVAL_BATCHES"
79
+ run_phase "serious_1b" "$SERIOUS_TOKENS" "$SERIOUS_EVAL_BATCHES"
80
+
81
+ echo "Selection gate reached after pilot_300m and serious_1b completed for all four variants."
code/TaoTrain/scripts/remote/run_pre_200m_branch_only_gate.sh ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
5
+ SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
6
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
7
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
8
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
9
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
10
+ OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/pre-200m-branch-only-gate}"
11
+ CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
12
+
13
+ SEQ_LEN="${SEQ_LEN:-512}"
14
+ BATCH_SIZE="${BATCH_SIZE:-8}"
15
+ TARGET_TOKENS="${TARGET_TOKENS:-100000000}"
16
+ MAX_TOKENS="${MAX_TOKENS:-100000000}"
17
+ MAX_RECORDS="${MAX_RECORDS:-200000}"
18
+ EVAL_BATCHES="${EVAL_BATCHES:-128}"
19
+ LEARNING_RATE="${LEARNING_RATE:-0.0008}"
20
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
21
+ TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-1000}"
22
+ SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
23
+ SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
24
+ SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
25
+ BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
26
+
27
+ ceil_div() {
28
+ local numerator="$1"
29
+ local denominator="$2"
30
+ echo $(( (numerator + denominator - 1) / denominator ))
31
+ }
32
+
33
+ TRAIN_STEPS="${TRAIN_STEPS:-$(ceil_div "$TARGET_TOKENS" $((BATCH_SIZE * SEQ_LEN)))}"
34
+
35
+ export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
36
+ mkdir -p "$OUTPUT_BASE" "$CHECKPOINT_BASE" "$OUTPUT_BASE/configs" "$OUTPUT_BASE/diagnostics"
37
+
38
+ cap_args=""
39
+ cap_yaml=""
40
+ if [[ -n "$BLOCK_RESIDUAL_RMS_CAP" ]]; then
41
+ cap_args="--block-residual-rms-cap $BLOCK_RESIDUAL_RMS_CAP"
42
+ cap_yaml=" block_residual_rms_cap: $BLOCK_RESIDUAL_RMS_CAP"
43
+ fi
44
+
45
+ cat > "$OUTPUT_BASE/run_plan.json" <<JSON
46
+ {
47
+ "purpose": "pre_200m_branch_only_quality_gate_before_4b_sft_chatbot_run",
48
+ "candidate": "pure_ssm_196m_branch_rms_only",
49
+ "target_tokens": $TARGET_TOKENS,
50
+ "train_steps": $TRAIN_STEPS,
51
+ "batch_size": $BATCH_SIZE,
52
+ "seq_len": $SEQ_LEN,
53
+ "learning_rate": $LEARNING_RATE,
54
+ "weight_decay": $WEIGHT_DECAY,
55
+ "block_residual_rms_cap": "${BLOCK_RESIDUAL_RMS_CAP:-null}",
56
+ "checks": [
57
+ "bounded pretrain loss/eval/grad telemetry",
58
+ "activation scale probe",
59
+ "sample generation",
60
+ "tiny SFT overfit probe"
61
+ ]
62
+ }
63
+ JSON
64
+
65
+ printf '\n============================================================\n'
66
+ printf 'Pre-200M branch-only gate: pure SSM stabilized candidate\n'
67
+ printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s cap=%s\n' \
68
+ "$TARGET_TOKENS" "$BATCH_SIZE" "$SEQ_LEN" "$TRAIN_STEPS" "$EVAL_BATCHES" "${BLOCK_RESIDUAL_RMS_CAP:-none}"
69
+ printf '============================================================\n'
70
+
71
+ "$PYTHON_BIN" scripts/benchmark_taonet_real_tokens.py \
72
+ --data-path "$DATA_PATH" \
73
+ --text-field text \
74
+ --tokenizer-type sentencepiece \
75
+ --tokenizer-path "$TOKENIZER_PATH" \
76
+ --max-records "$MAX_RECORDS" \
77
+ --max-tokens "$MAX_TOKENS" \
78
+ --eval-fraction 0.1 \
79
+ --architectures taonet_ssm \
80
+ --batch-sizes "$BATCH_SIZE" \
81
+ --seq-len "$SEQ_LEN" \
82
+ --hidden-dim 1024 \
83
+ --num-layers 18 \
84
+ --num-heads 8 \
85
+ --d-latent-kv 768 \
86
+ --d-rope 128 \
87
+ --hidden-dim-ff 3072 \
88
+ --dropout 0.0 \
89
+ --ssm-core dplr \
90
+ --ssm-hidden-dims 32 \
91
+ --ssm-mixer-dims 256 \
92
+ --ssm-num-lanes-list 2 \
93
+ --ssm-lane-combine channel \
94
+ --ssm-lane-modes split \
95
+ --ssm-split-mixes none \
96
+ --ssm-rank 1 \
97
+ --ssm-kernel-mode conv \
98
+ --no-ssm-finite-tail-correction \
99
+ --ssm-gate-types channel \
100
+ --dtype bf16 \
101
+ --device cuda \
102
+ --warmup 1 \
103
+ --repeats 2 \
104
+ --backward \
105
+ --train-steps "$TRAIN_STEPS" \
106
+ --train-log-every "$TRAIN_LOG_EVERY" \
107
+ --learning-rate "$LEARNING_RATE" \
108
+ --weight-decay "$WEIGHT_DECAY" \
109
+ --max-grad-norm 1.0 \
110
+ --eval-batches "$EVAL_BATCHES" \
111
+ --ssm-local-shift \
112
+ --ssm-local-shift-per-channel \
113
+ --ssm-local-shift-init 0.1 \
114
+ --ssm-branch-rms-norm \
115
+ $cap_args \
116
+ --output-dir "$OUTPUT_BASE/pretrain" \
117
+ --resume-completed \
118
+ --incremental-output \
119
+ --save-case-checkpoints \
120
+ --checkpoint-dir "$CHECKPOINT_BASE/pretrain"
121
+
122
+ PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/latest.pt"
123
+ if [[ ! -f "$PRETRAIN_CKPT" ]]; then
124
+ echo "Expected pretrain checkpoint missing: $PRETRAIN_CKPT" >&2
125
+ exit 2
126
+ fi
127
+
128
+ "$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
129
+ --checkpoint "$PRETRAIN_CKPT" \
130
+ --tokenizer-path "$TOKENIZER_PATH" \
131
+ --data-path "$DATA_PATH" \
132
+ --text-field text \
133
+ --output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json" \
134
+ --batch-size 2 \
135
+ --seq-len "$SEQ_LEN" \
136
+ --device cuda \
137
+ --dtype bfloat16
138
+
139
+ "$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
140
+ --checkpoint "$PRETRAIN_CKPT" \
141
+ --tokenizer-path "$TOKENIZER_PATH" \
142
+ --output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json" \
143
+ --max-new-tokens 120 \
144
+ --temperature 0.8 \
145
+ --top-p 0.9 \
146
+ --prompt "The purpose of artificial intelligence is" \
147
+ --prompt "In a small village," \
148
+ --prompt "<user>Hello, who are you?<assistant>"
149
+
150
+ SFT_CONFIG="$OUTPUT_BASE/configs/sft_sanity.yaml"
151
+ cat > "$SFT_CONFIG" <<YAML
152
+ model:
153
+ architecture_type: taonet_ssm
154
+ vocab_size: 8192
155
+ hidden_dim: 1024
156
+ num_layers: 18
157
+ num_heads: 8
158
+ max_seq_length: $SEQ_LEN
159
+ d_latent_kv: 768
160
+ d_rope: 128
161
+ hidden_dim_ff: 3072
162
+ dropout: 0.0
163
+ gqa_groups: 1
164
+ use_factorized_embedding: false
165
+ d_embed_rank: 96
166
+ init_std: 0.02
167
+ ssm_core: dplr
168
+ ssm_hidden_dim: 32
169
+ ssm_mixer_dim: 256
170
+ ssm_num_lanes: 2
171
+ ssm_lane_combine: channel
172
+ ssm_lane_mode: split
173
+ ssm_split_mix: none
174
+ ssm_rank: 1
175
+ ssm_max_low_rank_scale: 0.1
176
+ ssm_finite_tail_correction: false
177
+ ssm_discretization: bilinear
178
+ ssm_kernel_mode: conv
179
+ ssm_kernel_threshold: 64
180
+ ssm_dt_min: 1e-3
181
+ ssm_dt_max: 1e-1
182
+ ssm_dt_init: 1e-2
183
+ ssm_use_d: true
184
+ ssm_activation: gelu
185
+ ssm_gate: true
186
+ ssm_input_gate: true
187
+ ssm_gate_type: channel
188
+ ssm_use_padding_mask: false
189
+ ssm_layer_scale_init: 0.1
190
+ ssm_branch_rms_norm: true
191
+ block_residual_rms_norm: false
192
+ $cap_yaml
193
+ ssm_local_shift: true
194
+ ssm_local_shift_init: 0.1
195
+ ssm_local_shift_per_channel: true
196
+
197
+ dataset:
198
+ split: train
199
+ instruction_column: input
200
+ response_column: output
201
+ local: true
202
+ jsonl_path: $SFT_DATA_PATH
203
+ samples_per_chunk: 2000
204
+ tokenizer_type: sentencepiece
205
+ tokenizer_path: $TOKENIZER_PATH
206
+ tokenizer_threads: 8
207
+
208
+ checkpoint_path: $PRETRAIN_CKPT
209
+ user_token: "<user>"
210
+ assistant_token: "<assistant>"
211
+ response_loss_only: true
212
+
213
+ batch_size: $BATCH_SIZE
214
+ num_epochs: 100000
215
+ max_steps: $SFT_SANITY_STEPS
216
+ gradient_accumulation_steps: 1
217
+ max_grad_norm: 1.0
218
+
219
+ optimizer:
220
+ optimizer_type: adamw
221
+ learning_rate: $SFT_SANITY_LR
222
+ weight_decay: 0.0
223
+ betas: [0.9, 0.999]
224
+ eps: 1e-8
225
+
226
+ scheduler:
227
+ scheduler_type: linearWarmup
228
+ warmup_steps: 0
229
+
230
+ dtype: bfloat16
231
+ device: cuda
232
+ checkpoint_dir: $CHECKPOINT_BASE/sft_sanity
233
+ save_every_steps: 5000
234
+ save_best_model: false
235
+ keep_last_n_checkpoints: 1
236
+ eval_every_steps: 5000
237
+ eval_samples: 32
238
+ log_every_steps: 10
239
+ aim_repo: $OUTPUT_BASE/.aim-sft-sanity
240
+ seed: 43
241
+ num_workers: 0
242
+ pin_memory: true
243
+ YAML
244
+
245
+ "$PYTHON_BIN" scripts/diagnostics/sft_sanity_check.py \
246
+ --config "$SFT_CONFIG" \
247
+ --checkpoint "$PRETRAIN_CKPT" \
248
+ --output "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json" \
249
+ --samples "$SFT_SANITY_SAMPLES" \
250
+ --steps "$SFT_SANITY_STEPS" \
251
+ --lr "$SFT_SANITY_LR" \
252
+ --log-every 20 \
253
+ --device cuda \
254
+ --dtype bfloat16 \
255
+ --ssm-branch-rms-norm
256
+
257
+ cat > "$OUTPUT_BASE/gate_summary.json" <<JSON
258
+ {
259
+ "state": "completed",
260
+ "pretrain_checkpoint": "$PRETRAIN_CKPT",
261
+ "pretrain_results": "$OUTPUT_BASE/pretrain/taonet_real_token_benchmark.json",
262
+ "activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json",
263
+ "generation_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json",
264
+ "sft_sanity": "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json"
265
+ }
266
+ JSON
267
+
268
+ echo "Pre-200M branch-only gate completed."
269
+ cat "$OUTPUT_BASE/gate_summary.json"
code/TaoTrain/scripts/remote/run_pre_200m_stability_gate.sh ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
5
+ SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
6
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
7
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
8
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
9
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
10
+ OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/pre-200m-stability-gate}"
11
+ CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
12
+
13
+ SEQ_LEN="${SEQ_LEN:-512}"
14
+ BATCH_SIZE="${BATCH_SIZE:-8}"
15
+ TARGET_TOKENS="${TARGET_TOKENS:-20000000}"
16
+ MAX_TOKENS="${MAX_TOKENS:-50000000}"
17
+ MAX_RECORDS="${MAX_RECORDS:-120000}"
18
+ EVAL_BATCHES="${EVAL_BATCHES:-64}"
19
+ LEARNING_RATE="${LEARNING_RATE:-0.0008}"
20
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
21
+ TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-250}"
22
+ SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
23
+ SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
24
+ SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
25
+
26
+ ceil_div() {
27
+ local numerator="$1"
28
+ local denominator="$2"
29
+ echo $(( (numerator + denominator - 1) / denominator ))
30
+ }
31
+
32
+ TRAIN_STEPS="${TRAIN_STEPS:-$(ceil_div "$TARGET_TOKENS" $((BATCH_SIZE * SEQ_LEN)))}"
33
+
34
+ export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
35
+ mkdir -p "$OUTPUT_BASE" "$CHECKPOINT_BASE" "$OUTPUT_BASE/configs" "$OUTPUT_BASE/diagnostics"
36
+
37
+ cat > "$OUTPUT_BASE/run_plan.json" <<JSON
38
+ {
39
+ "purpose": "pre_200m_stability_gate_before_4b_sft_chatbot_run",
40
+ "candidate": "pure_ssm_196m_stabilized_m256_h32",
41
+ "target_tokens": $TARGET_TOKENS,
42
+ "train_steps": $TRAIN_STEPS,
43
+ "batch_size": $BATCH_SIZE,
44
+ "seq_len": $SEQ_LEN,
45
+ "learning_rate": $LEARNING_RATE,
46
+ "weight_decay": $WEIGHT_DECAY,
47
+ "checks": [
48
+ "bounded pretrain loss/eval/grad telemetry",
49
+ "activation scale probe",
50
+ "sample generation",
51
+ "tiny SFT overfit probe"
52
+ ]
53
+ }
54
+ JSON
55
+
56
+ printf '\n============================================================\n'
57
+ printf 'Pre-200M stability gate: pure SSM stabilized candidate\n'
58
+ printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s\n' \
59
+ "$TARGET_TOKENS" "$BATCH_SIZE" "$SEQ_LEN" "$TRAIN_STEPS" "$EVAL_BATCHES"
60
+ printf '============================================================\n'
61
+
62
+ "$PYTHON_BIN" scripts/benchmark_taonet_real_tokens.py \
63
+ --data-path "$DATA_PATH" \
64
+ --text-field text \
65
+ --tokenizer-type sentencepiece \
66
+ --tokenizer-path "$TOKENIZER_PATH" \
67
+ --max-records "$MAX_RECORDS" \
68
+ --max-tokens "$MAX_TOKENS" \
69
+ --eval-fraction 0.1 \
70
+ --architectures taonet_ssm \
71
+ --batch-sizes "$BATCH_SIZE" \
72
+ --seq-len "$SEQ_LEN" \
73
+ --hidden-dim 1024 \
74
+ --num-layers 18 \
75
+ --num-heads 8 \
76
+ --d-latent-kv 768 \
77
+ --d-rope 128 \
78
+ --hidden-dim-ff 3072 \
79
+ --dropout 0.0 \
80
+ --ssm-core dplr \
81
+ --ssm-hidden-dims 32 \
82
+ --ssm-mixer-dims 256 \
83
+ --ssm-num-lanes-list 2 \
84
+ --ssm-lane-combine channel \
85
+ --ssm-lane-modes split \
86
+ --ssm-split-mixes none \
87
+ --ssm-rank 1 \
88
+ --ssm-kernel-mode conv \
89
+ --no-ssm-finite-tail-correction \
90
+ --ssm-gate-types channel \
91
+ --dtype bf16 \
92
+ --device cuda \
93
+ --warmup 1 \
94
+ --repeats 2 \
95
+ --backward \
96
+ --train-steps "$TRAIN_STEPS" \
97
+ --train-log-every "$TRAIN_LOG_EVERY" \
98
+ --learning-rate "$LEARNING_RATE" \
99
+ --weight-decay "$WEIGHT_DECAY" \
100
+ --max-grad-norm 1.0 \
101
+ --eval-batches "$EVAL_BATCHES" \
102
+ --ssm-local-shift \
103
+ --ssm-local-shift-per-channel \
104
+ --ssm-local-shift-init 0.1 \
105
+ --ssm-branch-rms-norm \
106
+ --ssm-branch-clip-value 1.0 \
107
+ --block-residual-rms-norm \
108
+ --block-residual-rms-target 1.0 \
109
+ --output-dir "$OUTPUT_BASE/pretrain" \
110
+ --resume-completed \
111
+ --incremental-output \
112
+ --save-case-checkpoints \
113
+ --checkpoint-dir "$CHECKPOINT_BASE/pretrain"
114
+
115
+ PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/latest.pt"
116
+ if [[ ! -f "$PRETRAIN_CKPT" ]]; then
117
+ echo "Expected pretrain checkpoint missing: $PRETRAIN_CKPT" >&2
118
+ exit 2
119
+ fi
120
+
121
+ "$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
122
+ --checkpoint "$PRETRAIN_CKPT" \
123
+ --tokenizer-path "$TOKENIZER_PATH" \
124
+ --data-path "$DATA_PATH" \
125
+ --text-field text \
126
+ --output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json" \
127
+ --batch-size 2 \
128
+ --seq-len "$SEQ_LEN" \
129
+ --device cuda \
130
+ --dtype bfloat16
131
+
132
+ "$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
133
+ --checkpoint "$PRETRAIN_CKPT" \
134
+ --tokenizer-path "$TOKENIZER_PATH" \
135
+ --output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json" \
136
+ --max-new-tokens 80 \
137
+ --temperature 0.8 \
138
+ --top-p 0.9 \
139
+ --prompt "The purpose of artificial intelligence is" \
140
+ --prompt "In a small village," \
141
+ --prompt "<user>Hello, who are you?<assistant>"
142
+
143
+ SFT_CONFIG="$OUTPUT_BASE/configs/sft_sanity.yaml"
144
+ cat > "$SFT_CONFIG" <<YAML
145
+ model:
146
+ architecture_type: taonet_ssm
147
+ vocab_size: 8192
148
+ hidden_dim: 1024
149
+ num_layers: 18
150
+ num_heads: 8
151
+ max_seq_length: $SEQ_LEN
152
+ d_latent_kv: 768
153
+ d_rope: 128
154
+ hidden_dim_ff: 3072
155
+ dropout: 0.0
156
+ gqa_groups: 1
157
+ use_factorized_embedding: false
158
+ d_embed_rank: 96
159
+ init_std: 0.02
160
+ ssm_core: dplr
161
+ ssm_hidden_dim: 32
162
+ ssm_mixer_dim: 256
163
+ ssm_num_lanes: 2
164
+ ssm_lane_combine: channel
165
+ ssm_lane_mode: split
166
+ ssm_split_mix: none
167
+ ssm_rank: 1
168
+ ssm_max_low_rank_scale: 0.1
169
+ ssm_finite_tail_correction: false
170
+ ssm_discretization: bilinear
171
+ ssm_kernel_mode: conv
172
+ ssm_kernel_threshold: 64
173
+ ssm_dt_min: 1e-3
174
+ ssm_dt_max: 1e-1
175
+ ssm_dt_init: 1e-2
176
+ ssm_use_d: true
177
+ ssm_activation: gelu
178
+ ssm_gate: true
179
+ ssm_input_gate: true
180
+ ssm_gate_type: channel
181
+ ssm_use_padding_mask: false
182
+ ssm_layer_scale_init: 0.1
183
+ ssm_branch_rms_norm: true
184
+ ssm_branch_clip_value: 1.0
185
+ block_residual_rms_norm: true
186
+ block_residual_rms_target: 1.0
187
+ ssm_local_shift: true
188
+ ssm_local_shift_init: 0.1
189
+ ssm_local_shift_per_channel: true
190
+
191
+ dataset:
192
+ split: train
193
+ instruction_column: input
194
+ response_column: output
195
+ local: true
196
+ jsonl_path: $SFT_DATA_PATH
197
+ samples_per_chunk: 2000
198
+ tokenizer_type: sentencepiece
199
+ tokenizer_path: $TOKENIZER_PATH
200
+ tokenizer_threads: 8
201
+
202
+ checkpoint_path: $PRETRAIN_CKPT
203
+ user_token: "<user>"
204
+ assistant_token: "<assistant>"
205
+ response_loss_only: true
206
+
207
+ batch_size: $BATCH_SIZE
208
+ num_epochs: 100000
209
+ max_steps: $SFT_SANITY_STEPS
210
+ gradient_accumulation_steps: 1
211
+ max_grad_norm: 1.0
212
+
213
+ optimizer:
214
+ optimizer_type: adamw
215
+ learning_rate: $SFT_SANITY_LR
216
+ weight_decay: 0.0
217
+ betas: [0.9, 0.999]
218
+ eps: 1e-8
219
+
220
+ scheduler:
221
+ scheduler_type: linearWarmup
222
+ warmup_steps: 0
223
+
224
+ dtype: bfloat16
225
+ device: cuda
226
+ checkpoint_dir: $CHECKPOINT_BASE/sft_sanity
227
+ save_every_steps: 5000
228
+ save_best_model: false
229
+ keep_last_n_checkpoints: 1
230
+ eval_every_steps: 5000
231
+ eval_samples: 32
232
+ log_every_steps: 10
233
+ aim_repo: $OUTPUT_BASE/.aim-sft-sanity
234
+ seed: 43
235
+ num_workers: 0
236
+ pin_memory: true
237
+ YAML
238
+
239
+ "$PYTHON_BIN" scripts/diagnostics/sft_sanity_check.py \
240
+ --config "$SFT_CONFIG" \
241
+ --checkpoint "$PRETRAIN_CKPT" \
242
+ --output "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json" \
243
+ --samples "$SFT_SANITY_SAMPLES" \
244
+ --steps "$SFT_SANITY_STEPS" \
245
+ --lr "$SFT_SANITY_LR" \
246
+ --log-every 20 \
247
+ --device cuda \
248
+ --dtype bfloat16 \
249
+ --ssm-branch-rms-norm \
250
+ --ssm-branch-clip-value 1.0 \
251
+ --block-residual-rms-norm \
252
+ --block-residual-rms-target 1.0
253
+
254
+ cat > "$OUTPUT_BASE/gate_summary.json" <<JSON
255
+ {
256
+ "state": "completed",
257
+ "pretrain_checkpoint": "$PRETRAIN_CKPT",
258
+ "pretrain_results": "$OUTPUT_BASE/pretrain/taonet_real_token_benchmark.json",
259
+ "activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json",
260
+ "generation_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json",
261
+ "sft_sanity": "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json"
262
+ }
263
+ JSON
264
+
265
+ echo "Pre-200M stability gate completed."
266
+ cat "$OUTPUT_BASE/gate_summary.json"
267
+
code/TaoTrain/scripts/remote/submit_200m_base_suite.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-taotern-200m-base-suite-$(date +%Y%m%d-%H%M%S)}"
5
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
6
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
8
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
9
+
10
+ JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH bash scripts/remote/run_200m_base_suite.sh"
11
+
12
+ export RUN_ID JOB_ROOT JOB_COMMAND
13
+ export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
14
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
15
+
16
+ bash scripts/remote/submit_detached_job.sh
code/TaoTrain/scripts/remote/submit_200m_branch_only_chat.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-taotern-200m-branch-only-chat-$(date +%Y%m%d-%H%M%S)}"
5
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
6
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
8
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
9
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
10
+ SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
11
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
12
+ SEQ_LEN="${SEQ_LEN:-512}"
13
+ BATCH_SIZE="${BATCH_SIZE:-8}"
14
+ PRETRAIN_TOKENS="${PRETRAIN_TOKENS:-4000000000}"
15
+ SFT_STEPS="${SFT_STEPS:-50000}"
16
+ PRETRAIN_LR="${PRETRAIN_LR:-0.0008}"
17
+ SFT_LR="${SFT_LR:-0.00005}"
18
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
19
+ LOG_EVERY="${LOG_EVERY:-100}"
20
+ SAVE_EVERY="${SAVE_EVERY:-100000}"
21
+ SFT_SAVE_EVERY="${SFT_SAVE_EVERY:-10000}"
22
+ TOKENIZER_THREADS="${TOKENIZER_THREADS:-8}"
23
+ SAMPLES_PER_CHUNK="${SAMPLES_PER_CHUNK:-2000}"
24
+ BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
25
+
26
+ JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH SFT_DATA_PATH=$SFT_DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE PRETRAIN_TOKENS=$PRETRAIN_TOKENS SFT_STEPS=$SFT_STEPS PRETRAIN_LR=$PRETRAIN_LR SFT_LR=$SFT_LR WEIGHT_DECAY=$WEIGHT_DECAY LOG_EVERY=$LOG_EVERY SAVE_EVERY=$SAVE_EVERY SFT_SAVE_EVERY=$SFT_SAVE_EVERY TOKENIZER_THREADS=$TOKENIZER_THREADS SAMPLES_PER_CHUNK=$SAMPLES_PER_CHUNK BLOCK_RESIDUAL_RMS_CAP=$BLOCK_RESIDUAL_RMS_CAP bash scripts/remote/run_200m_branch_only_chat.sh"
27
+
28
+ export RUN_ID JOB_ROOT JOB_COMMAND
29
+ export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
30
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
31
+
32
+ bash scripts/remote/submit_detached_job.sh
code/TaoTrain/scripts/remote/submit_200m_until_selection.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-taotern-200m-until-selection-$(date +%Y%m%d-%H%M%S)}"
5
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
6
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
8
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
9
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
10
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
11
+ SEQ_LEN="${SEQ_LEN:-512}"
12
+ BATCH_SIZE="${BATCH_SIZE:-8}"
13
+ PILOT_TOKENS="${PILOT_TOKENS:-300000000}"
14
+ SERIOUS_TOKENS="${SERIOUS_TOKENS:-1000000000}"
15
+ PILOT_EVAL_BATCHES="${PILOT_EVAL_BATCHES:-64}"
16
+ SERIOUS_EVAL_BATCHES="${SERIOUS_EVAL_BATCHES:-128}"
17
+ MAX_TOKENS="${MAX_TOKENS:-50000000}"
18
+ MAX_RECORDS="${MAX_RECORDS:-100000}"
19
+ LEARNING_RATE="${LEARNING_RATE:-0.0003}"
20
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
21
+
22
+ JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE PILOT_TOKENS=$PILOT_TOKENS SERIOUS_TOKENS=$SERIOUS_TOKENS PILOT_EVAL_BATCHES=$PILOT_EVAL_BATCHES SERIOUS_EVAL_BATCHES=$SERIOUS_EVAL_BATCHES MAX_TOKENS=$MAX_TOKENS MAX_RECORDS=$MAX_RECORDS LEARNING_RATE=$LEARNING_RATE WEIGHT_DECAY=$WEIGHT_DECAY bash scripts/remote/run_200m_until_selection.sh"
23
+
24
+ export RUN_ID JOB_ROOT JOB_COMMAND
25
+ export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
26
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
27
+
28
+ bash scripts/remote/submit_detached_job.sh
code/TaoTrain/scripts/remote/submit_detached_job.sh ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-}"
5
+ JOB_COMMAND="${JOB_COMMAND:-}"
6
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
7
+ SESSION_NAME="${SESSION_NAME:-}"
8
+
9
+ if [[ -z "$RUN_ID" ]]; then
10
+ echo "RUN_ID is required" >&2
11
+ exit 2
12
+ fi
13
+
14
+ if [[ -z "$JOB_COMMAND" ]]; then
15
+ echo "JOB_COMMAND is required" >&2
16
+ exit 2
17
+ fi
18
+
19
+ if ! command -v tmux >/dev/null 2>&1; then
20
+ echo "tmux is required for detached jobs" >&2
21
+ exit 2
22
+ fi
23
+
24
+ safe_run_id="$(printf '%s' "$RUN_ID" | tr -c 'A-Za-z0-9_.-' '_')"
25
+ session="${SESSION_NAME:-taotern_${safe_run_id}}"
26
+ job_dir="${JOB_ROOT%/}/${safe_run_id}"
27
+ output_dir="${OUTPUT_DIR:-$job_dir/outputs}"
28
+ checkpoint_dir="${CHECKPOINT_DIR:-$job_dir/checkpoints}"
29
+
30
+ mkdir -p "$job_dir" "$output_dir" "$checkpoint_dir"
31
+
32
+ cat > "$job_dir/command.sh" <<EOF
33
+ #!/usr/bin/env bash
34
+ set -euo pipefail
35
+ export REPOBRIDGE_OUTPUT_DIR="$output_dir"
36
+ export TAOTERN_CHECKPOINT_DIR="$checkpoint_dir"
37
+ cd "$(pwd)"
38
+ $JOB_COMMAND
39
+ EOF
40
+ chmod +x "$job_dir/command.sh"
41
+
42
+ cat > "$job_dir/run.sh" <<'EOF'
43
+ #!/usr/bin/env bash
44
+ set -Eeuo pipefail
45
+
46
+ job_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
47
+ status_json="$job_dir/status.json"
48
+ log_path="$job_dir/train.log"
49
+ telemetry_path="$job_dir/gpu_telemetry_nvidia_smi.csv"
50
+
51
+ write_status() {
52
+ local state="$1"
53
+ local extra="${2:-}"
54
+ local now
55
+ now="$(date -Iseconds)"
56
+ cat > "$status_json" <<STATUS
57
+ {"state":"$state","updated_at":"$now","job_dir":"$job_dir"$extra}
58
+ STATUS
59
+ }
60
+
61
+ write_status "running" ",\"start_time\":\"$(date -Iseconds)\""
62
+ rm -f "$job_dir/DONE" "$job_dir/FAILED"
63
+
64
+ monitor_pid=""
65
+ if command -v nvidia-smi >/dev/null 2>&1; then
66
+ (
67
+ while true; do
68
+ date -Iseconds
69
+ nvidia-smi --query-gpu=timestamp,name,utilization.gpu,utilization.memory,memory.used,memory.total,power.draw,temperature.gpu --format=csv,noheader,nounits
70
+ sleep 5
71
+ done
72
+ ) > "$telemetry_path" 2>&1 &
73
+ monitor_pid="$!"
74
+ fi
75
+
76
+ cleanup() {
77
+ if [[ -n "$monitor_pid" ]]; then
78
+ kill "$monitor_pid" 2>/dev/null || true
79
+ wait "$monitor_pid" 2>/dev/null || true
80
+ fi
81
+ }
82
+ trap cleanup EXIT
83
+
84
+ set +e
85
+ "$job_dir/command.sh" > "$log_path" 2>&1
86
+ exit_code="$?"
87
+ set -e
88
+
89
+ if [[ "$exit_code" -eq 0 ]]; then
90
+ touch "$job_dir/DONE"
91
+ write_status "completed" ",\"exit_code\":0,\"end_time\":\"$(date -Iseconds)\""
92
+ else
93
+ echo "$exit_code" > "$job_dir/FAILED"
94
+ write_status "failed" ",\"exit_code\":$exit_code,\"end_time\":\"$(date -Iseconds)\""
95
+ fi
96
+
97
+ exit "$exit_code"
98
+ EOF
99
+ chmod +x "$job_dir/run.sh"
100
+
101
+ cat > "$job_dir/status.json" <<STATUS
102
+ {"state":"submitted","updated_at":"$(date -Iseconds)","job_dir":"$job_dir","session":"$session","output_dir":"$output_dir","checkpoint_dir":"$checkpoint_dir"}
103
+ STATUS
104
+
105
+ if tmux has-session -t "$session" 2>/dev/null; then
106
+ echo "tmux session already exists: $session" >&2
107
+ echo "$job_dir"
108
+ exit 3
109
+ fi
110
+
111
+ tmux new-session -d -s "$session" "bash '$job_dir/run.sh'"
112
+
113
+ echo "Submitted detached job"
114
+ echo " run_id: $safe_run_id"
115
+ echo " session: $session"
116
+ echo " job_dir: $job_dir"
117
+ echo " output_dir: $output_dir"
118
+ echo " checkpoint_dir: $checkpoint_dir"
code/TaoTrain/scripts/remote/submit_pre_200m_branch_only_gate.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-taotern-pre-200m-branch-only-gate-$(date +%Y%m%d-%H%M%S)}"
5
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
6
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
8
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
9
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
10
+ SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
11
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
12
+ SEQ_LEN="${SEQ_LEN:-512}"
13
+ BATCH_SIZE="${BATCH_SIZE:-8}"
14
+ TARGET_TOKENS="${TARGET_TOKENS:-100000000}"
15
+ MAX_TOKENS="${MAX_TOKENS:-100000000}"
16
+ MAX_RECORDS="${MAX_RECORDS:-200000}"
17
+ EVAL_BATCHES="${EVAL_BATCHES:-128}"
18
+ LEARNING_RATE="${LEARNING_RATE:-0.0008}"
19
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
20
+ TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-1000}"
21
+ SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
22
+ SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
23
+ SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
24
+ BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
25
+
26
+ JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH SFT_DATA_PATH=$SFT_DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE TARGET_TOKENS=$TARGET_TOKENS MAX_TOKENS=$MAX_TOKENS MAX_RECORDS=$MAX_RECORDS EVAL_BATCHES=$EVAL_BATCHES LEARNING_RATE=$LEARNING_RATE WEIGHT_DECAY=$WEIGHT_DECAY TRAIN_LOG_EVERY=$TRAIN_LOG_EVERY SFT_SANITY_SAMPLES=$SFT_SANITY_SAMPLES SFT_SANITY_STEPS=$SFT_SANITY_STEPS SFT_SANITY_LR=$SFT_SANITY_LR BLOCK_RESIDUAL_RMS_CAP=$BLOCK_RESIDUAL_RMS_CAP bash scripts/remote/run_pre_200m_branch_only_gate.sh"
27
+
28
+ export RUN_ID JOB_ROOT JOB_COMMAND
29
+ export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
30
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
31
+
32
+ bash scripts/remote/submit_detached_job.sh
33
+
code/TaoTrain/scripts/remote/submit_pre_200m_stability_gate.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-taotern-pre-200m-stability-gate-$(date +%Y%m%d-%H%M%S)}"
5
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
6
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
8
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
9
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
10
+ SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
11
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
12
+ SEQ_LEN="${SEQ_LEN:-512}"
13
+ BATCH_SIZE="${BATCH_SIZE:-8}"
14
+ TARGET_TOKENS="${TARGET_TOKENS:-20000000}"
15
+ MAX_TOKENS="${MAX_TOKENS:-50000000}"
16
+ MAX_RECORDS="${MAX_RECORDS:-120000}"
17
+ EVAL_BATCHES="${EVAL_BATCHES:-64}"
18
+ LEARNING_RATE="${LEARNING_RATE:-0.0008}"
19
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
20
+ TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-250}"
21
+ SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
22
+ SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
23
+ SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
24
+
25
+ JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH SFT_DATA_PATH=$SFT_DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE TARGET_TOKENS=$TARGET_TOKENS MAX_TOKENS=$MAX_TOKENS MAX_RECORDS=$MAX_RECORDS EVAL_BATCHES=$EVAL_BATCHES LEARNING_RATE=$LEARNING_RATE WEIGHT_DECAY=$WEIGHT_DECAY TRAIN_LOG_EVERY=$TRAIN_LOG_EVERY SFT_SANITY_SAMPLES=$SFT_SANITY_SAMPLES SFT_SANITY_STEPS=$SFT_SANITY_STEPS SFT_SANITY_LR=$SFT_SANITY_LR bash scripts/remote/run_pre_200m_stability_gate.sh"
26
+
27
+ export RUN_ID JOB_ROOT JOB_COMMAND
28
+ export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
29
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
30
+
31
+ bash scripts/remote/submit_detached_job.sh
32
+
code/TaoTrain/scripts/remote/submit_ssm_improvement_sweep.sh ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ RUN_ID="${RUN_ID:-ssm-improvement-sweep-$(date +%Y%m%d-%H%M%S)}"
5
+ DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
6
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
7
+ SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
8
+ PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
9
+ JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
10
+ REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
11
+
12
+ # This is intentionally below the eventual 200M scale. It is a model-selection
13
+ # sweep that can run unattended and resume completed benchmark cases.
14
+ MAX_TOKENS="${MAX_TOKENS:-100000000}"
15
+ MAX_RECORDS="${MAX_RECORDS:-150000}"
16
+ TRAIN_STEPS="${TRAIN_STEPS:-5000}"
17
+ EVAL_BATCHES="${EVAL_BATCHES:-96}"
18
+ BATCH_SIZES="${BATCH_SIZES:-32,64}"
19
+ SEQ_LEN="${SEQ_LEN:-512}"
20
+
21
+ JOB_COMMAND="PYTHONPATH=$REMOTE_REPO/src:$SSM_REPO_PATH $PYTHON_BIN scripts/benchmark_taonet_real_tokens.py \
22
+ --data-path $DATA_PATH \
23
+ --text-field text \
24
+ --tokenizer-type sentencepiece \
25
+ --tokenizer-path $TOKENIZER_PATH \
26
+ --max-records $MAX_RECORDS \
27
+ --max-tokens $MAX_TOKENS \
28
+ --eval-fraction 0.1 \
29
+ --architectures taonet,taonet_ssm,taonet_hybrid \
30
+ --batch-sizes $BATCH_SIZES \
31
+ --seq-len $SEQ_LEN \
32
+ --hidden-dim 256 \
33
+ --num-layers 4 \
34
+ --num-heads 4 \
35
+ --d-latent-kv 192 \
36
+ --hidden-dim-ff 1024 \
37
+ --ssm-core dplr \
38
+ --ssm-hidden-dims 16,32 \
39
+ --ssm-mixer-dims 128,256 \
40
+ --ssm-num-lanes-list 1,2 \
41
+ --ssm-lane-combine channel \
42
+ --ssm-lane-modes full,split \
43
+ --ssm-split-mixes none,hadamard \
44
+ --ssm-rank 1 \
45
+ --ssm-kernel-mode conv \
46
+ --no-ssm-finite-tail-correction \
47
+ --ssm-gate-types channel \
48
+ --hybrid-patterns attention_first,ssm_first,single_ssm_middle,single_ssm_late \
49
+ --dtype bf16 \
50
+ --device cuda \
51
+ --warmup 2 \
52
+ --repeats 3 \
53
+ --backward \
54
+ --train-steps $TRAIN_STEPS \
55
+ --learning-rate 0.0008 \
56
+ --weight-decay 0.01 \
57
+ --eval-batches $EVAL_BATCHES \
58
+ --ssm-local-shift \
59
+ --ssm-local-shift-per-channel \
60
+ --ssm-local-shift-init 0.1 \
61
+ --output-dir \"\$REPOBRIDGE_OUTPUT_DIR\" \
62
+ --resume-completed \
63
+ --incremental-output \
64
+ --save-case-checkpoints \
65
+ --checkpoint-dir \"\$TAOTERN_CHECKPOINT_DIR\""
66
+
67
+ export RUN_ID JOB_ROOT JOB_COMMAND
68
+ export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
69
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
70
+
71
+ bash scripts/remote/submit_detached_job.sh
code/TaoTrain/scripts/summarize_taonet_benchmark_suite.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import csv
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
+ def _as_float(value: str | None) -> float | None:
11
+ if value is None or value == "":
12
+ return None
13
+ try:
14
+ return float(value)
15
+ except ValueError:
16
+ return None
17
+
18
+
19
+ def _load_rows(root: Path) -> list[dict[str, Any]]:
20
+ rows: list[dict[str, Any]] = []
21
+ for csv_path in sorted(root.glob("*/taonet_real_token_benchmark.csv")):
22
+ variant = csv_path.parent.name
23
+ with csv_path.open("r", newline="", encoding="utf-8") as handle:
24
+ for row in csv.DictReader(handle):
25
+ row = dict(row)
26
+ row["variant"] = variant
27
+ rows.append(row)
28
+ return rows
29
+
30
+
31
+ def _best_forward_backward(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
32
+ candidates = [row for row in rows if row.get("mode") == "forward_backward"]
33
+ grouped: dict[str, list[dict[str, Any]]] = {}
34
+ for row in candidates:
35
+ grouped.setdefault(row["variant"], []).append(row)
36
+
37
+ best_rows = []
38
+ for variant, items in grouped.items():
39
+ items.sort(
40
+ key=lambda row: (
41
+ _as_float(row.get("eval_loss")) if _as_float(row.get("eval_loss")) is not None else float("inf"),
42
+ -(_as_float(row.get("eval_accuracy")) or 0.0),
43
+ )
44
+ )
45
+ best_rows.append(items[0])
46
+ best_rows.sort(
47
+ key=lambda row: (
48
+ _as_float(row.get("eval_loss")) if _as_float(row.get("eval_loss")) is not None else float("inf"),
49
+ -(_as_float(row.get("eval_accuracy")) or 0.0),
50
+ )
51
+ )
52
+ return best_rows
53
+
54
+
55
+ def _project(row: dict[str, Any]) -> dict[str, Any]:
56
+ keys = [
57
+ "variant",
58
+ "architecture",
59
+ "hybrid_pattern",
60
+ "batch_size",
61
+ "seq_len",
62
+ "total_params",
63
+ "ssm_core",
64
+ "ssm_hidden_dim",
65
+ "ssm_mixer_dim",
66
+ "ssm_num_lanes",
67
+ "ssm_lane_mode",
68
+ "ssm_split_mix",
69
+ "tokens_per_s_mean",
70
+ "eval_loss",
71
+ "eval_perplexity",
72
+ "eval_accuracy",
73
+ "train_final_loss",
74
+ "train_seconds",
75
+ "peak_reserved_mb",
76
+ "case_id",
77
+ "checkpoint_path",
78
+ ]
79
+ return {key: row.get(key, "") for key in keys}
80
+
81
+
82
+ def _write_markdown(summary: list[dict[str, Any]], path: Path) -> None:
83
+ headers = [
84
+ "variant",
85
+ "architecture",
86
+ "batch",
87
+ "params",
88
+ "eval_loss",
89
+ "eval_acc",
90
+ "tok/s",
91
+ "checkpoint",
92
+ ]
93
+ lines = [
94
+ "# TaoNet Benchmark Suite Summary",
95
+ "",
96
+ "| " + " | ".join(headers) + " |",
97
+ "| " + " | ".join(["---"] * len(headers)) + " |",
98
+ ]
99
+ for row in summary:
100
+ lines.append(
101
+ "| "
102
+ + " | ".join(
103
+ [
104
+ str(row["variant"]),
105
+ str(row["architecture"]),
106
+ str(row["batch_size"]),
107
+ str(row["total_params"]),
108
+ str(row["eval_loss"]),
109
+ str(row["eval_accuracy"]),
110
+ str(row["tokens_per_s_mean"]),
111
+ str(row["checkpoint_path"]),
112
+ ]
113
+ )
114
+ + " |"
115
+ )
116
+ path.write_text("\n".join(lines) + "\n", encoding="utf-8")
117
+
118
+
119
+ def main() -> None:
120
+ parser = argparse.ArgumentParser(description="Summarize a TaoNet benchmark suite output directory.")
121
+ parser.add_argument("--suite-dir", required=True, help="Directory containing one subdirectory per benchmark variant.")
122
+ parser.add_argument("--output-json", default="", help="Summary JSON path. Defaults to <suite-dir>/suite_summary.json.")
123
+ parser.add_argument("--output-md", default="", help="Summary Markdown path. Defaults to <suite-dir>/suite_summary.md.")
124
+ args = parser.parse_args()
125
+
126
+ suite_dir = Path(args.suite_dir)
127
+ rows = _load_rows(suite_dir)
128
+ summary = [_project(row) for row in _best_forward_backward(rows)]
129
+ json_path = Path(args.output_json) if args.output_json else suite_dir / "suite_summary.json"
130
+ md_path = Path(args.output_md) if args.output_md else suite_dir / "suite_summary.md"
131
+ json_path.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
132
+ _write_markdown(summary, md_path)
133
+ print(f"Wrote {json_path}")
134
+ print(f"Wrote {md_path}")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()
code/TaoTrain/src/taoTrain.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
code/TaoTrain/src/taoTrain.egg-info/entry_points.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [console_scripts]
2
+ train = taoTrain.cli:main
3
+ train-tokenizer = taoTrain.cli:train_tokenizer_command
4
+ tui-chat = taoTrain.inference.tui:main
code/TaoTrain/src/taoTrain/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TaoTrain: A clean, modular PyTorch LLM training framework.
3
+
4
+ For quick usage, see: https://github.com/your-org/taoTrain
5
+ """
6
+
7
+ __version__ = "0.1.0"
8
+ __author__ = "Felix"
code/TaoTrain/src/taoTrain/cli.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main CLI entry point."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ import click
7
+ import torch
8
+
9
+ from taoTrain.config import load_config, load_tokenizer_config, TrainingModeEnum, PretrainConfig, SFTConfig, RLConfig
10
+ from taoTrain.utils import set_seed, get_device
11
+ from taoTrain.core import BaseModel, create_model, create_datasets
12
+ from taoTrain.data import get_dataloader
13
+ from taoTrain.training import PretrainTrainer, SFTTrainer, RLTrainer
14
+ from taoTrain.benchmarks import BenchmarkRunner
15
+ from taoTrain.tokenizers import TokenizerTrainer
16
+
17
+
18
+ @click.group()
19
+ def main():
20
+ """TaoTrain: A clean, modular PyTorch LLM training framework."""
21
+ pass
22
+
23
+
24
+ @main.command()
25
+ @click.option(
26
+ "--config",
27
+ type=click.Path(exists=True),
28
+ required=True,
29
+ help="Path to training config file (YAML or JSON)",
30
+ )
31
+ def pretrain(config: str):
32
+ """Pretrain a language model."""
33
+ _train_command(config, TrainingModeEnum.PRETRAIN)
34
+
35
+
36
+ @main.command()
37
+ @click.option(
38
+ "--config",
39
+ type=click.Path(exists=True),
40
+ required=True,
41
+ help="Path to training config file (YAML or JSON)",
42
+ )
43
+ def sft(config: str):
44
+ """Supervised fine-tune a language model."""
45
+ _train_command(config, TrainingModeEnum.SFT)
46
+
47
+
48
+ @main.command()
49
+ @click.option(
50
+ "--config",
51
+ type=click.Path(exists=True),
52
+ required=True,
53
+ help="Path to training config file (YAML or JSON)",
54
+ )
55
+ def rl(config: str):
56
+ """Train with reinforcement learning."""
57
+ _train_command(config, TrainingModeEnum.RL)
58
+
59
+
60
+ @main.command()
61
+ @click.option(
62
+ "--config",
63
+ type=click.Path(exists=True),
64
+ required=True,
65
+ help="Path to tokenizer config file (YAML or JSON)",
66
+ )
67
+ def train_tokenizer(config: str):
68
+ """Train a SentencePiece tokenizer from a YAML/JSON config file."""
69
+ try:
70
+ click.echo("🚀 TaoTrain Tokenizer Trainer")
71
+ click.echo(f"{'=' * 50}")
72
+
73
+ # Load tokenizer config
74
+ click.echo(f"Loading config from {config}...")
75
+ tokenizer_config = load_tokenizer_config(config)
76
+
77
+ # Train tokenizer from config
78
+ result = TokenizerTrainer.train_from_config(tokenizer_config)
79
+
80
+ # Display results
81
+ click.echo(f"\n{'=' * 50}")
82
+ click.echo("✅ Tokenizer Training Complete!")
83
+ click.echo(f"\n📊 Configuration:")
84
+ click.echo(f" - Input file: {tokenizer_config.jsonl_path}")
85
+ click.echo(f" - Samples: {tokenizer_config.max_samples or 'all'}")
86
+ click.echo(f" - Output dir: {result['output_dir']}")
87
+ click.echo(f" - Vocab size: {result['vocab_size']}")
88
+ click.echo(f" - Model type: {result['model_type']}")
89
+
90
+ if tokenizer_config.special_tokens:
91
+ click.echo(f" - Special tokens: {tokenizer_config.special_tokens}")
92
+
93
+ click.echo(f"\n📁 Generated Files:")
94
+ click.echo(f" - Model: {result['model_file']}")
95
+ click.echo(f" - Vocab: {result['vocab_file']}")
96
+
97
+ click.echo(f"\n📝 Next Steps:")
98
+ click.echo(f" 1. Use this tokenizer in your pretraining config:")
99
+ click.echo(f" dataset:")
100
+ click.echo(f" local: true")
101
+ click.echo(f" jsonl_path: {tokenizer_config.jsonl_path}")
102
+ click.echo(f" tokenizer_path: {result['model_file']}")
103
+ click.echo(f"")
104
+ click.echo(f" 2. Run pretraining with:")
105
+ click.echo(f" train pretrain --config your_config.yaml")
106
+
107
+ except ImportError as e:
108
+ click.echo(f"❌ Error: {e}", err=True)
109
+ sys.exit(1)
110
+ except FileNotFoundError as e:
111
+ click.echo(f"❌ File Error: {e}", err=True)
112
+ sys.exit(1)
113
+ except ValueError as e:
114
+ click.echo(f"❌ Validation Error: {e}", err=True)
115
+ sys.exit(1)
116
+ except Exception as e:
117
+ click.echo(f"❌ Unexpected Error: {e}", err=True)
118
+ click.echo(f" Please report this issue.", err=True)
119
+ sys.exit(1)
120
+
121
+
122
+ @click.command()
123
+ @click.option(
124
+ "--jsonl-path",
125
+ type=click.Path(exists=True),
126
+ required=True,
127
+ help="Path to JSONL file containing training data",
128
+ )
129
+ @click.option(
130
+ "--output-dir",
131
+ type=click.Path(),
132
+ default="tokenizers",
133
+ help="Directory to save tokenizer files",
134
+ )
135
+ @click.option(
136
+ "--vocab-size",
137
+ type=int,
138
+ default=50000,
139
+ help="Vocabulary size for the tokenizer",
140
+ )
141
+ @click.option(
142
+ "--model-type",
143
+ type=click.Choice(["unigram", "bpe", "char", "word"]),
144
+ default="unigram",
145
+ help="SentencePiece model type",
146
+ )
147
+ @click.option(
148
+ "--character-coverage",
149
+ type=float,
150
+ default=0.9995,
151
+ help="Character coverage for SentencePiece",
152
+ )
153
+ @click.option(
154
+ "--tokenizer-prefix",
155
+ type=str,
156
+ default=None,
157
+ help="Prefix for tokenizer output files (default: model_type)",
158
+ )
159
+ def train_tokenizer_command(
160
+ jsonl_path: str,
161
+ output_dir: str,
162
+ vocab_size: int,
163
+ model_type: str,
164
+ character_coverage: float,
165
+ tokenizer_prefix: Optional[str],
166
+ ):
167
+ """Train a SentencePiece tokenizer from JSONL data."""
168
+ try:
169
+ click.echo("🚀 TaoTrain Tokenizer Trainer")
170
+ click.echo(f"{'=' * 50}")
171
+
172
+ # Train tokenizer
173
+ result = TokenizerTrainer.train_sentencepiece(
174
+ jsonl_path=jsonl_path,
175
+ output_dir=output_dir,
176
+ vocab_size=vocab_size,
177
+ model_type=model_type,
178
+ character_coverage=character_coverage,
179
+ tokenizer_prefix=tokenizer_prefix,
180
+ )
181
+
182
+ # Display results
183
+ click.echo(f"\n{'=' * 50}")
184
+ click.echo("✅ Tokenizer Training Complete!")
185
+ click.echo(f"\n📊 Configuration:")
186
+ click.echo(f" - Input file: {jsonl_path}")
187
+ click.echo(f" - Output dir: {result['output_dir']}")
188
+ click.echo(f" - Vocab size: {result['vocab_size']}")
189
+ click.echo(f" - Model type: {result['model_type']}")
190
+
191
+ click.echo(f"\n📁 Generated Files:")
192
+ click.echo(f" - Model: {result['model_file']}")
193
+ click.echo(f" - Vocab: {result['vocab_file']}")
194
+
195
+ click.echo(f"\n📝 Next Steps:")
196
+ click.echo(f" 1. Use this tokenizer in your pretraining config:")
197
+ click.echo(f" dataset:")
198
+ click.echo(f" local: true")
199
+ click.echo(f" jsonl_path: {jsonl_path}")
200
+ click.echo(f" tokenizer_path: {result['model_file']}")
201
+ click.echo(f"")
202
+ click.echo(f" 2. Run pretraining with:")
203
+ click.echo(f" train pretrain --config your_config.yaml")
204
+
205
+ except ImportError as e:
206
+ click.echo(f"❌ Error: {e}", err=True)
207
+ sys.exit(1)
208
+ except FileNotFoundError as e:
209
+ click.echo(f"❌ File Error: {e}", err=True)
210
+ sys.exit(1)
211
+ except ValueError as e:
212
+ click.echo(f"❌ Validation Error: {e}", err=True)
213
+ sys.exit(1)
214
+ except Exception as e:
215
+ click.echo(f"❌ Unexpected Error: {e}", err=True)
216
+ click.echo(f" Please report this issue.", err=True)
217
+ sys.exit(1)
218
+
219
+
220
+ # Keep legacy CLI command as train-tokenizer-legacy for backward compatibility
221
+ main.add_command(train_tokenizer_command, name="train-tokenizer-legacy")
222
+
223
+
224
+ def _train_command(config_path: str, mode: TrainingModeEnum):
225
+ """Internal training command."""
226
+ try:
227
+ # Load config
228
+ click.echo(f"Loading config from {config_path}...")
229
+ train_config = load_config(config_path, mode)
230
+
231
+ # Set seed
232
+ set_seed(train_config.seed)
233
+
234
+ # Get device
235
+ device = get_device(train_config.device)
236
+ click.echo(f"Using device: {device}")
237
+
238
+ # Create model
239
+ click.echo("Creating model...")
240
+ model = create_model(train_config, device)
241
+ total_params, trainable_params = _count_params(model)
242
+ click.echo(f" - Total parameters: {total_params:,}")
243
+ click.echo(f" - Trainable parameters: {trainable_params:,}")
244
+
245
+ # Load pretrained checkpoint if provided (for SFT/RL)
246
+ if train_config.checkpoint_path:
247
+ click.echo(f"Loading pretrained checkpoint from {train_config.checkpoint_path}...")
248
+ from taoTrain.checkpointing.checkpoint import CheckpointManager
249
+ checkpoint_manager = CheckpointManager(train_config.checkpoint_dir)
250
+ checkpoint = checkpoint_manager.load(train_config.checkpoint_path, device=device)
251
+
252
+ # CheckpointManager.load() normalizes format and ensures 'model_state' key exists
253
+ if "model_state" in checkpoint:
254
+ model.load_state_dict(checkpoint["model_state"], strict=False)
255
+ click.echo(" ✓ Checkpoint loaded successfully")
256
+ else:
257
+ raise KeyError(f"Invalid checkpoint format: 'model_state' key not found. "
258
+ f"Available keys: {list(checkpoint.keys())}")
259
+
260
+
261
+
262
+ # Create datasets
263
+ click.echo("Loading datasets...")
264
+ train_dataset, val_dataset = create_datasets(train_config)
265
+ click.echo(f" - Train samples: {len(train_dataset)}")
266
+ if val_dataset:
267
+ click.echo(f" - Val samples: {len(val_dataset)}")
268
+
269
+ # Select trainer
270
+ if mode == TrainingModeEnum.PRETRAIN:
271
+ trainer_class = PretrainTrainer
272
+ elif mode == TrainingModeEnum.SFT:
273
+ trainer_class = SFTTrainer
274
+ elif mode == TrainingModeEnum.RL:
275
+ trainer_class = RLTrainer
276
+ else:
277
+ raise ValueError(f"Unknown training mode: {mode}")
278
+
279
+ # Create trainer
280
+ click.echo("Setting up trainer...")
281
+ trainer = trainer_class(
282
+ model=model,
283
+ train_dataset=train_dataset,
284
+ val_dataset=val_dataset,
285
+ config=train_config,
286
+ device=device,
287
+ )
288
+
289
+ # Training loop
290
+ click.echo("\nStarting training...\n")
291
+ for epoch in range(train_config.num_epochs):
292
+ if train_config.max_steps and trainer.global_step >= train_config.max_steps:
293
+ break
294
+
295
+ epoch_metrics = trainer.train_epoch()
296
+ click.echo(f"\nEpoch {epoch + 1} complete")
297
+ click.echo(f" - Loss: {epoch_metrics.get('loss', 'N/A')}")
298
+ click.echo(f" - Learning rate: {epoch_metrics.get('lr', 'N/A')}")
299
+
300
+ # Final checkpoint
301
+ final_path = Path(train_config.checkpoint_dir) / "final_model.pt"
302
+ trainer.save_checkpoint(final_path)
303
+ click.echo(f"\nTraining complete! Final model saved to {final_path}")
304
+
305
+ # Log finish
306
+ trainer.logger.finish()
307
+
308
+ except Exception as e:
309
+ click.echo(f"Error during training: {e}", err=True)
310
+ sys.exit(1)
311
+
312
+
313
+ @main.command()
314
+ @click.option(
315
+ "--model",
316
+ type=click.Path(exists=True),
317
+ required=True,
318
+ help="Path to model checkpoint",
319
+ )
320
+ @click.option(
321
+ "--benchmark-type",
322
+ type=click.Choice(["all", "perplexity", "throughput", "memory"]),
323
+ default="all",
324
+ help="Type of benchmark to run",
325
+ )
326
+ @click.option(
327
+ "--batch-size",
328
+ type=int,
329
+ default=32,
330
+ help="Batch size for benchmarking",
331
+ )
332
+ @click.option(
333
+ "--seq-length",
334
+ type=int,
335
+ default=1024,
336
+ help="Sequence length for benchmarking",
337
+ )
338
+ def benchmark(model: str, benchmark_type: str, batch_size: int, seq_length: int):
339
+ """Benchmark a trained model."""
340
+ try:
341
+ click.echo(f"Loading model from {model}...")
342
+ device = get_device("cuda")
343
+ runner = BenchmarkRunner.load_from_checkpoint(model, device=device)
344
+
345
+ click.echo("Running benchmarks...\n")
346
+
347
+ if benchmark_type == "throughput" or benchmark_type == "all":
348
+ click.echo("Throughput benchmark:")
349
+ results = runner.benchmark_throughput(batch_size, seq_length)
350
+ for key, val in results.items():
351
+ click.echo(f" {key}: {val:.2f}")
352
+
353
+ if benchmark_type == "memory" or benchmark_type == "all":
354
+ click.echo("\nMemory benchmark:")
355
+ results = runner.benchmark_memory()
356
+ for key, val in results.items():
357
+ click.echo(f" {key}: {val:.2f}")
358
+
359
+ click.echo("\nBenchmarking complete!")
360
+
361
+ except Exception as e:
362
+ click.echo(f"Error during benchmarking: {e}", err=True)
363
+ sys.exit(1)
364
+
365
+
366
+ @main.command()
367
+ @click.option("--repo", type=str, default=".aim", help="AimStack repository path")
368
+ def view_logs(repo: str):
369
+ """View training logs with AimStack."""
370
+ try:
371
+ import subprocess
372
+ click.echo(f"Opening AimStack dashboard for repo: {repo}")
373
+ subprocess.run(["aim", "up", "--repo", repo])
374
+ except FileNotFoundError:
375
+ click.echo("Error: 'aim' command not found. Install with: pip install aim", err=True)
376
+ sys.exit(1)
377
+
378
+
379
+ def _count_params(model: BaseModel) -> tuple[int, int]:
380
+ """Count model parameters."""
381
+ total = sum(p.numel() for p in model.parameters())
382
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
383
+ return total, trainable
384
+
385
+
386
+ if __name__ == "__main__":
387
+ main()
code/TaoTrain/src/taoTrain/config.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic configuration schemas for TaoTrain."""
2
+
3
+ from enum import Enum
4
+ from typing import Optional, Literal
5
+ from pathlib import Path
6
+ import json
7
+ from pydantic import BaseModel as PydanticBaseModel, Field, validator
8
+ import yaml
9
+
10
+
11
+ # ============================================================================
12
+ # Enums
13
+ # ============================================================================
14
+
15
+
16
+ class DataTypeEnum(str, Enum):
17
+ """Data types for training."""
18
+ FLOAT32 = "float32"
19
+ FLOAT16 = "float16"
20
+ BFLOAT16 = "bfloat16"
21
+
22
+
23
+ class OptimizerEnum(str, Enum):
24
+ """Supported optimizers."""
25
+ ADAM = "adam"
26
+ ADAMW = "adamw"
27
+ SGD = "sgd"
28
+ HYBRID_MUON_ADAMW = "hybrid_muon_adamw"
29
+
30
+
31
+ class ModelArchitectureEnum(str, Enum):
32
+ """Built-in model architectures."""
33
+ TRANSFORMER = "transformer"
34
+ TAONET = "taonet"
35
+ TAONET_SSM = "taonet_ssm"
36
+ TAONET_HYBRID = "taonet_hybrid"
37
+
38
+
39
+ class SchedulerEnum(str, Enum):
40
+ """Supported learning rate schedulers."""
41
+ LINEAR_WARMUP = "linearWarmup"
42
+ COSINE_WARMUP = "cosineWarmup"
43
+ CONSTANT = "constant"
44
+
45
+
46
+ class RLMethodEnum(str, Enum):
47
+ """Supported RL training methods."""
48
+ PPO = "ppo"
49
+ DPO = "dpo"
50
+
51
+
52
+ class TrainingModeEnum(str, Enum):
53
+ """Training stages."""
54
+ PRETRAIN = "pretrain"
55
+ SFT = "sft"
56
+ RL = "rl"
57
+
58
+
59
+ # ============================================================================
60
+ # Base Configs
61
+ # ============================================================================
62
+
63
+
64
+ class BaseConfig(PydanticBaseModel):
65
+ """Base Pydantic model with utility methods."""
66
+
67
+ class Config:
68
+ """Pydantic config."""
69
+ arbitrary_types_allowed = True
70
+
71
+ def to_dict(self) -> dict:
72
+ """Convert to dictionary."""
73
+ data = self.model_dump(mode='json') # Enums -> strings
74
+ return data
75
+
76
+ def to_json_str(self) -> str:
77
+ """Convert to JSON string."""
78
+ return json.dumps(self.to_dict(), indent=2)
79
+
80
+ def save_yaml(self, path: str | Path) -> None:
81
+ """Save config to YAML file."""
82
+ path = Path(path)
83
+ path.parent.mkdir(parents=True, exist_ok=True)
84
+ with open(path, 'w') as f:
85
+ yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)
86
+
87
+ def save_json(self, path: str | Path) -> None:
88
+ """Save config to JSON file."""
89
+ path = Path(path)
90
+ path.parent.mkdir(parents=True, exist_ok=True)
91
+ with open(path, 'w') as f:
92
+ f.write(self.to_json_str())
93
+
94
+ @classmethod
95
+ def load_yaml(cls, path: str | Path) -> "BaseConfig":
96
+ """Load config from YAML file."""
97
+ with open(path) as f:
98
+ data = yaml.safe_load(f)
99
+ return cls(**data)
100
+
101
+ @classmethod
102
+ def load_json(cls, path: str | Path) -> "BaseConfig":
103
+ """Load config from JSON file."""
104
+ with open(path) as f:
105
+ data = json.load(f)
106
+ return cls(**data)
107
+
108
+
109
+ # ============================================================================
110
+ # Model Config
111
+ # ============================================================================
112
+
113
+
114
+ class ModelConfig(BaseConfig):
115
+ """Configuration for model architecture."""
116
+
117
+ architecture_type: ModelArchitectureEnum = Field(
118
+ default=ModelArchitectureEnum.TRANSFORMER,
119
+ description="Type of model architecture"
120
+ )
121
+
122
+ # Transformer-specific
123
+ vocab_size: int = Field(default=50257, description="Vocabulary size")
124
+ hidden_dim: int = Field(default=768, description="Hidden dimension")
125
+ num_layers: int = Field(default=12, description="Number of transformer blocks")
126
+ num_heads: int = Field(default=12, description="Number of attention heads")
127
+ head_dim: Optional[int] = Field(
128
+ default=None,
129
+ description="Head dimension (defaults to hidden_dim // num_heads)"
130
+ )
131
+ intermediate_dim: Optional[int] = Field(
132
+ default=None,
133
+ description="FFN intermediate dimension (defaults to 4 * hidden_dim)"
134
+ )
135
+ dropout: float = Field(default=0.1, description="Dropout rate")
136
+ max_seq_length: int = Field(default=2048, description="Maximum sequence length")
137
+
138
+ # TaoNet (DeepSeek MLA) specific
139
+ d_latent_kv: Optional[int] = Field(
140
+ default=None,
141
+ description="KV compression dimension for MLA (defaults to 3/4 * hidden_dim). Only used for taonet architecture."
142
+ )
143
+ d_rope: Optional[int] = Field(
144
+ default=None,
145
+ description="RoPE dimension per head (defaults to hidden_dim // num_heads). Only used for taonet architecture."
146
+ )
147
+ gqa_groups: int = Field(
148
+ default=1,
149
+ description="Grouped Query Attention groups (1 = standard MLA, >1 = GQA). Only used for taonet architecture."
150
+ )
151
+ hidden_dim_ff: Optional[int] = Field(
152
+ default=None,
153
+ description="Feed-forward intermediate dimension (defaults to 4 * hidden_dim)."
154
+ )
155
+ use_factorized_embedding: bool = Field(
156
+ default=False,
157
+ description="Use low-rank factorized embedding instead of standard embedding (reduces params). Only for taonet."
158
+ )
159
+ d_embed_rank: int = Field(
160
+ default=96,
161
+ description="Rank dimension for factorized embedding. Only used if use_factorized_embedding=True."
162
+ )
163
+
164
+ # YaRN (Yet another RoPE eXtension) for context length extension
165
+ rope_scale: float = Field(
166
+ default=40.0,
167
+ description="Base RoPE scale factor (default: 40.0). Controls position frequency base."
168
+ )
169
+ yarn_enabled: bool = Field(
170
+ default=False,
171
+ description="Enable YaRN (Yet another RoPE eXtension) for context length interpolation."
172
+ )
173
+ yarn_alpha: float = Field(
174
+ default=1.0,
175
+ description="YaRN interpolation smoothness (1.0=smooth, <1.0=aggressive, >1.0=conservative). Only used if yarn_enabled=True."
176
+ )
177
+
178
+ # TaoNet-SSM specific: SSM mixer replacing MLA attention
179
+ ssm_core: Literal["gamma_s4", "dplr"] = Field(
180
+ default="gamma_s4",
181
+ description="SSM core used by taonet_ssm. Use dplr for the ternary-aware DPLR SSM."
182
+ )
183
+ ssm_hidden_dim: Optional[int] = Field(
184
+ default=None,
185
+ description="SSM hidden/state dimension for taonet_ssm. Defaults to d_latent_kv or hidden_dim."
186
+ )
187
+ ssm_mixer_dim: Optional[int] = Field(
188
+ default=None,
189
+ description="Channel dimension processed by the SSM mixer. Defaults to hidden_dim; smaller values use an input/output projection bottleneck."
190
+ )
191
+ ssm_num_lanes: int = Field(
192
+ default=1,
193
+ description="Number of independent SSM lanes inside each SSM mixer. Multiple lanes add SSM capacity with cheap elementwise combination."
194
+ )
195
+ ssm_lane_combine: Literal["mean", "channel"] = Field(
196
+ default="mean",
197
+ description="How to combine multiple SSM lanes. Channel uses learned per-lane/per-channel elementwise weights."
198
+ )
199
+ ssm_lane_mode: Literal["full", "split"] = Field(
200
+ default="full",
201
+ description="Whether each SSM lane processes the full mixer dimension or a disjoint split of the mixer channels."
202
+ )
203
+ ssm_split_mix: Literal["none", "hadamard"] = Field(
204
+ default="none",
205
+ description="Optional ternary-friendly cross-lane mixer for split SSM lanes."
206
+ )
207
+ ssm_rank: int = Field(
208
+ default=1,
209
+ description="Low-rank correction rank for ssm_core=dplr."
210
+ )
211
+ ssm_max_low_rank_scale: float = Field(
212
+ default=0.1,
213
+ description="Maximum low-rank correction scale for ssm_core=dplr."
214
+ )
215
+ ssm_finite_tail_correction: bool = Field(
216
+ default=True,
217
+ description="Enable exact finite-length tail correction for ssm_core=dplr. Disable for the faster approximate DPLR path."
218
+ )
219
+ ssm_discretization: Literal["bilinear", "zoh", "euler"] = Field(
220
+ default="bilinear",
221
+ description="Discretization used by the Gamma SSM mixer."
222
+ )
223
+ ssm_kernel_mode: Literal["auto", "recurrent", "conv", "conv_transfer"] = Field(
224
+ default="auto",
225
+ description="Gamma SSM execution path. Use auto/conv for full-sequence GPU training, conv_transfer to materialize frequency transfers, recurrent for step-wise tests."
226
+ )
227
+ ssm_kernel_threshold: int = Field(
228
+ default=64,
229
+ description="Minimum sequence length for auto mode to use the convolutional Gamma SSM path."
230
+ )
231
+ ssm_dt_min: float = Field(default=1e-3, description="Minimum learned SSM timestep.")
232
+ ssm_dt_max: float = Field(default=1e-1, description="Maximum learned SSM timestep.")
233
+ ssm_dt_init: float = Field(default=1e-2, description="Initial learned SSM timestep.")
234
+ ssm_use_d: bool = Field(default=True, description="Enable direct skip term D in the Gamma SSM.")
235
+ ssm_activation: Literal["gelu", "silu", "identity", "linear"] = Field(
236
+ default="gelu",
237
+ description="Activation applied to the Gamma SSM branch output."
238
+ )
239
+ ssm_gate: bool = Field(default=True, description="Enable output gate on the Gamma SSM branch.")
240
+ ssm_input_gate: bool = Field(default=True, description="Enable input gate before the Gamma SSM.")
241
+ ssm_gate_type: Literal["dense", "channel"] = Field(
242
+ default="dense",
243
+ description="Gate implementation for enabled SSM input/output gates. Channel gates are elementwise and ternary-friendly."
244
+ )
245
+ ssm_use_padding_mask: bool = Field(
246
+ default=False,
247
+ description="Apply dataset padding masks inside the SSM. Disabled by default so training can use the convolutional path."
248
+ )
249
+ ssm_layer_scale_init: float = Field(
250
+ default=0.1,
251
+ description="Initial layer-scale multiplier for the Gamma SSM branch."
252
+ )
253
+ ssm_branch_rms_norm: bool = Field(
254
+ default=False,
255
+ description="Normalize the SSM residual branch to unit RMS before layer-scale. Useful for stabilizing deep SSM/hybrid runs."
256
+ )
257
+ ssm_branch_rms_eps: float = Field(
258
+ default=1e-6,
259
+ description="Numerical epsilon for optional SSM branch RMS normalization."
260
+ )
261
+ ssm_branch_clip_value: Optional[float] = Field(
262
+ default=None,
263
+ description="Optional symmetric clamp applied to the SSM residual branch after layer-scale. None disables clamping."
264
+ )
265
+ block_residual_rms_norm: bool = Field(
266
+ default=False,
267
+ description="Normalize the residual stream RMS after block residual additions. Intended for stabilizing deep SSM/hybrid experiments."
268
+ )
269
+ block_residual_rms_target: float = Field(
270
+ default=1.0,
271
+ description="Target per-token RMS when block_residual_rms_norm is enabled."
272
+ )
273
+ block_residual_rms_cap: Optional[float] = Field(
274
+ default=None,
275
+ description="Optional per-token RMS cap for the residual stream. Unlike block_residual_rms_norm, this only scales down tokens whose RMS exceeds the cap."
276
+ )
277
+ block_residual_rms_eps: float = Field(
278
+ default=1e-6,
279
+ description="Numerical epsilon for optional block residual RMS normalization."
280
+ )
281
+ ssm_local_shift: bool = Field(
282
+ default=False,
283
+ description="Add a cheap one-token causal shift/register branch to the taonet_ssm mixer."
284
+ )
285
+ ssm_local_shift_init: float = Field(
286
+ default=0.1,
287
+ description="Initial scalar weight for the optional one-token local shift/register branch."
288
+ )
289
+ ssm_local_shift_per_channel: bool = Field(
290
+ default=False,
291
+ description="Use one learned local-shift gain per model channel instead of one scalar."
292
+ )
293
+ hybrid_pattern: Literal["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"] = Field(
294
+ default="attention_first",
295
+ description="Layer pattern for taonet_hybrid when hybrid_ssm_layers is not set."
296
+ )
297
+ hybrid_ssm_layers: Optional[str] = Field(
298
+ default=None,
299
+ description="Optional comma-separated 0-based layer indices that should use SSM blocks in taonet_hybrid."
300
+ )
301
+
302
+ # Initializations
303
+ init_std: float = Field(default=0.02, description="Weight initialization standard deviation")
304
+
305
+ @validator("head_dim", always=True)
306
+ def validate_head_dim(cls, v, values):
307
+ """Validate head dimension."""
308
+ if v is None and 'hidden_dim' in values:
309
+ return values['hidden_dim'] // values.get('num_heads', 12)
310
+ return v
311
+
312
+ @validator("intermediate_dim", always=True)
313
+ def validate_intermediate_dim(cls, v, values):
314
+ """Validate intermediate dimension."""
315
+ if v is None and 'hidden_dim' in values:
316
+ return 4 * values['hidden_dim']
317
+ return v
318
+
319
+
320
+ # ============================================================================
321
+ # Dataset Config
322
+ # ============================================================================
323
+
324
+
325
+ class DatasetConfig(BaseConfig):
326
+ """Configuration for dataset loading."""
327
+
328
+ # Local vs HuggingFace dataset selection
329
+ local: bool = Field(default=False, description="Use local JSONL dataset instead of HuggingFace")
330
+
331
+ # HuggingFace dataset fields
332
+ dataset_name: Optional[str] = Field(default=None, description="HuggingFace dataset name (e.g., 'wikitext', 'openwebtext')")
333
+ split: str = Field(default="train", description="Dataset split to use")
334
+ config: Optional[str] = Field(default=None, description="Dataset config if multi-config (e.g., 'wikitext-103')")
335
+
336
+ # Local JSONL dataset fields
337
+ jsonl_path: Optional[str] = Field(default=None, description="Path to local JSONL dataset file")
338
+ text_field: str = Field(default="text", description="Name of text field in JSONL")
339
+
340
+ # Text column name varies by dataset
341
+ text_column: str = Field(default="text", description="Name of text column in dataset")
342
+
343
+ # Preprocessing
344
+ max_samples: Optional[int] = Field(
345
+ default=None,
346
+ description="Limit dataset to N samples (useful for debugging)"
347
+ )
348
+ cache_dir: str = Field(default=".cache/datasets", description="HuggingFace cache directory")
349
+
350
+ # For SFT/RL datasets with instruction-response format
351
+ instruction_column: Optional[str] = Field(default=None, description="Instruction column for SFT")
352
+ response_column: Optional[str] = Field(default=None, description="Response column for SFT")
353
+ prompt_column: Optional[str] = Field(default=None, description="Prompt column for RL")
354
+
355
+ # Instruction template
356
+ instruction_template: Optional[str] = Field(
357
+ default=None,
358
+ description="Template for combining instruction and response. E.g., '{instruction}\\n{response}'"
359
+ )
360
+
361
+ # Tokenizer configuration
362
+ tokenizer_type: Optional[str] = Field(
363
+ default=None,
364
+ description="Tokenizer type: 'huggingface' or 'sentencepiece'. If None, defaults based on tokenizer_path."
365
+ )
366
+ tokenizer_path: Optional[str] = Field(
367
+ default=None,
368
+ description="Path to saved tokenizer (for SentencePiece: .model file, for HuggingFace: model name or local path)"
369
+ )
370
+
371
+ # Chunked loading for large JSONL files
372
+ enable_streaming: bool = Field(
373
+ default=True,
374
+ description="Enable streaming/chunked loading for large JSONL files to reduce memory usage"
375
+ )
376
+ chunk_size_gb: float = Field(
377
+ default=5.0,
378
+ description="Approximate chunk size in GB (ignored if samples_per_chunk is set)"
379
+ )
380
+ samples_per_chunk: Optional[int] = Field(
381
+ default=1000,
382
+ description="Number of samples per chunk (takes precedence over chunk_size_gb). Default: 1000 samples"
383
+ )
384
+
385
+ # Chunk caching
386
+ enable_chunk_metadata_cache: bool = Field(
387
+ default=True,
388
+ description="Enable caching of chunk metadata (file scan results) to avoid re-scanning large JSONL files"
389
+ )
390
+ enable_chunk_data_cache: bool = Field(
391
+ default=False,
392
+ description="Enable caching of actual chunk data as separate files for faster loading (uses more disk space)"
393
+ )
394
+ chunk_cache_dir: str = Field(
395
+ default=".cache/chunks",
396
+ description="Directory to store chunk metadata and data cache files"
397
+ )
398
+
399
+ # Tokenization parallelization
400
+ tokenizer_threads: int = Field(
401
+ default=1,
402
+ description="Number of background threads for tokenization (1-32 recommended). Higher values speed up tokenization but increase memory usage."
403
+ )
404
+
405
+ @validator('jsonl_path', always=True)
406
+ def validate_dataset_source(cls, v, values):
407
+ """Validate that either local JSONL or HuggingFace dataset is specified."""
408
+ local = values.get('local', False)
409
+ dataset_name = values.get('dataset_name')
410
+
411
+ if local and not v:
412
+ raise ValueError("jsonl_path must be provided when local=True")
413
+ if not local and not dataset_name:
414
+ raise ValueError("dataset_name must be provided when local=False (HuggingFace dataset)")
415
+
416
+ return v
417
+
418
+ @validator('tokenizer_threads')
419
+ def validate_tokenizer_threads(cls, v):
420
+ """Validate tokenizer_threads is a positive integer."""
421
+ if v < 1:
422
+ raise ValueError("tokenizer_threads must be at least 1")
423
+ if v > 128:
424
+ raise ValueError("tokenizer_threads should not exceed 128 (recommended: 1-32)")
425
+ return v
426
+
427
+
428
+ # ============================================================================
429
+ # Tokenizer Config
430
+ # ============================================================================
431
+
432
+
433
+ class TokenizerConfig(BaseConfig):
434
+ """Configuration for tokenizer training."""
435
+
436
+ # Dataset source
437
+ jsonl_path: str = Field(description="Path to JSONL file containing training data")
438
+ text_field: str = Field(default="text", description="Field name in JSONL for text data")
439
+
440
+ # Training configuration
441
+ vocab_size: int = Field(default=50000, description="Vocabulary size")
442
+ model_type: str = Field(default="unigram", description="SentencePiece model type (unigram, bpe, char, word)")
443
+ character_coverage: float = Field(
444
+ default=0.9995,
445
+ description="Character coverage for SentencePiece training"
446
+ )
447
+ output_dir: str = Field(default="tokenizers", description="Directory to save trained tokenizer")
448
+ tokenizer_prefix: Optional[str] = Field(
449
+ default=None,
450
+ description="Prefix for tokenizer output files (default: model_type)"
451
+ )
452
+
453
+ # SentencePiece token IDs
454
+ unk_id: int = Field(default=0, description="Unknown token ID")
455
+ bos_id: int = Field(default=1, description="Beginning of sentence token ID")
456
+ eos_id: int = Field(default=2, description="End of sentence token ID")
457
+ pad_id: int = Field(default=3, description="Padding token ID")
458
+
459
+ # Custom special tokens - add custom tokens like <PAD>, <EOS>, <BOS>, <UNK>, <think>, <user>, <assistant>, <image>
460
+ special_tokens: Optional[dict[str, int]] = Field(
461
+ default=None,
462
+ description="Custom special tokens mapping: {token: id}. Example: {'<think>': 4, '<user>': 5, '<assistant>': 6, '<image>': 7}"
463
+ )
464
+
465
+ # Data sampling
466
+ max_samples: Optional[int] = Field(
467
+ default=None,
468
+ description="Limit training to first N samples from JSONL (useful for quick testing)"
469
+ )
470
+
471
+ # Tokenizer metadata
472
+ tokenizer_name: Optional[str] = Field(
473
+ default=None,
474
+ description="Optional name for the tokenizer"
475
+ )
476
+
477
+
478
+ # ============================================================================
479
+ # Training Config
480
+ # ============================================================================
481
+
482
+
483
+ class OptimizerConfig(BaseConfig):
484
+ """Optimizer configuration."""
485
+
486
+ optimizer_type: OptimizerEnum = Field(default=OptimizerEnum.ADAMW, description="Optimizer type")
487
+ learning_rate: float = Field(default=1e-4, description="Peak learning rate (for Muon 2D weights)")
488
+ adamw_lr: Optional[float] = Field(
489
+ default=None,
490
+ description="Learning rate for AdamW (1D parameters). If None, defaults to learning_rate / 10. Used in hybrid_muon_adamw optimizer."
491
+ )
492
+ weight_decay: float = Field(default=1e-2, description="Weight decay (L2 regularization)")
493
+ betas: tuple[float, float] = Field(default=(0.9, 0.999), description="Adam betas")
494
+ eps: float = Field(default=1e-8, description="Optimizer epsilon")
495
+
496
+ @validator('adamw_lr', always=True)
497
+ def set_default_adamw_lr(cls, v, values):
498
+ """Set default adamw_lr as 1/10 of learning_rate if not specified."""
499
+ if v is None and 'learning_rate' in values:
500
+ return values['learning_rate'] / 10
501
+ return v
502
+
503
+
504
+ class SchedulerConfig(BaseConfig):
505
+ """Learning rate scheduler configuration."""
506
+
507
+ scheduler_type: SchedulerEnum = Field(default=SchedulerEnum.LINEAR_WARMUP, description="Scheduler type")
508
+ warmup_steps: int = Field(default=0, description="Number of warmup steps (takes precedence over warmup_ratio)")
509
+ warmup_ratio: float = Field(default=0.1, description="Warmup as fraction of total steps (used if warmup_steps=0)")
510
+
511
+ # Cosine scheduler specific
512
+ num_cycles: float = Field(default=0.5, description="Number of cycles for cosine schedule")
513
+ last_epoch: int = Field(default=-1, description="Last epoch for scheduler")
514
+
515
+ # TaoNet 3-phase scheduler (warmup -> steady -> cosine decay)
516
+ steady_ratio: float = Field(
517
+ default=0.0,
518
+ description="Fraction of training steps at peak LR before cosine decay (0.0 = no steady phase). Only for cosineWarmup."
519
+ )
520
+ min_lr_ratio: float = Field(
521
+ default=0.0,
522
+ description="Minimum LR as fraction of peak LR at end of training (0.0 = decay to 0). Only for cosineWarmup."
523
+ )
524
+
525
+ @validator('warmup_ratio')
526
+ def validate_warmup_ratio(cls, v):
527
+ """Validate warmup ratio is between 0 and 1."""
528
+ if not 0 <= v <= 1:
529
+ raise ValueError("warmup_ratio must be between 0 and 1")
530
+ return v
531
+
532
+ @validator('steady_ratio')
533
+ def validate_steady_ratio(cls, v):
534
+ """Validate steady ratio is between 0 and 1."""
535
+ if not 0 <= v <= 1:
536
+ raise ValueError("steady_ratio must be between 0 and 1")
537
+ return v
538
+
539
+ @validator('min_lr_ratio')
540
+ def validate_min_lr_ratio(cls, v):
541
+ """Validate min_lr_ratio is between 0 and 1."""
542
+ if not 0 <= v <= 1:
543
+ raise ValueError("min_lr_ratio must be between 0 and 1")
544
+ return v
545
+
546
+ @validator('warmup_steps')
547
+ def validate_warmup_steps(cls, v):
548
+ """Validate warmup steps is non-negative."""
549
+ if v < 0:
550
+ raise ValueError("warmup_steps must be non-negative")
551
+ return v
552
+
553
+
554
+ class TrainingConfig(BaseConfig):
555
+ """Base training configuration shared across all modes."""
556
+
557
+ # Data and model
558
+ model: ModelConfig = Field(default_factory=ModelConfig, description="Model configuration")
559
+ dataset: DatasetConfig = Field(description="Dataset configuration")
560
+
561
+ # Training hyperparameters
562
+ batch_size: int = Field(default=32, description="Batch size per device")
563
+ num_epochs: int = Field(default=3, description="Number of training epochs")
564
+ max_steps: Optional[int] = Field(
565
+ default=None,
566
+ description="Maximum steps (overrides num_epochs if set)"
567
+ )
568
+ gradient_accumulation_steps: int = Field(
569
+ default=1,
570
+ description="Gradient accumulation steps"
571
+ )
572
+ max_grad_norm: float = Field(default=1.0, description="Gradient clipping max norm")
573
+
574
+ # Optimizer
575
+ optimizer: OptimizerConfig = Field(
576
+ default_factory=OptimizerConfig,
577
+ description="Optimizer configuration"
578
+ )
579
+
580
+ # Scheduler
581
+ scheduler: SchedulerConfig = Field(
582
+ default_factory=SchedulerConfig,
583
+ description="Learning rate scheduler configuration"
584
+ )
585
+
586
+ # Data type and device
587
+ dtype: DataTypeEnum = Field(
588
+ default=DataTypeEnum.BFLOAT16,
589
+ description="Training data type"
590
+ )
591
+ device: str = Field(default="cuda", description="Device to train on (cuda, cpu)")
592
+ seed: int = Field(default=42, description="Random seed")
593
+
594
+ # Checkpointing
595
+ checkpoint_dir: str = Field(default="checkpoints", description="Directory to save checkpoints")
596
+ checkpoint_path: Optional[str] = Field(
597
+ default=None,
598
+ description="Path to load pretrained checkpoint (for SFT/RL). If provided, loads weights before training starts."
599
+ )
600
+ save_every_steps: int = Field(default=500, description="Save checkpoint every N steps")
601
+ keep_last_n_checkpoints: int = Field(default=3, description="Keep only last N checkpoints")
602
+ save_best_model: bool = Field(default=True, description="Save best model based on validation loss")
603
+
604
+ # Validation
605
+ eval_every_steps: int = Field(default=500, description="Evaluate every N steps")
606
+ eval_samples: int = Field(default=1000, description="Number of validation samples")
607
+
608
+ # Logging
609
+ log_every_steps: int = Field(default=10, description="Log metrics every N steps")
610
+ aim_repo: str = Field(default=".aim", description="AimStack repository path")
611
+
612
+ # Misc
613
+ num_workers: int = Field(default=0, description="Number of DataLoader workers")
614
+ pin_memory: bool = Field(default=True, description="Pin memory for DataLoader")
615
+ use_compile: bool = Field(default=False, description="Use torch.compile (experimental)")
616
+
617
+ # Mode
618
+ mode: TrainingModeEnum = Field(default=TrainingModeEnum.PRETRAIN, description="Training mode")
619
+
620
+
621
+ # ============================================================================
622
+ # Stage-Specific Configs
623
+ # ============================================================================
624
+
625
+
626
+ class PretrainConfig(TrainingConfig):
627
+ """Configuration for pretraining."""
628
+
629
+ mode: Literal[TrainingModeEnum.PRETRAIN] = TrainingModeEnum.PRETRAIN
630
+
631
+ # Pretraining-specific
632
+ sequence_length: int = Field(default=1024, description="Sequence length for pretraining")
633
+
634
+
635
+ class SFTConfig(TrainingConfig):
636
+ """Configuration for supervised fine-tuning."""
637
+
638
+ mode: Literal[TrainingModeEnum.SFT] = TrainingModeEnum.SFT
639
+
640
+ # SFT-specific
641
+ response_loss_only: bool = Field(
642
+ default=True,
643
+ description="Only compute loss on response/assistant tokens (not instruction/user tokens). Uses -100 label masking."
644
+ )
645
+
646
+ # Multi-turn conversation role tokens
647
+ user_token: str = Field(
648
+ default="<user>",
649
+ description="Special token representing user/instruction role in conversations"
650
+ )
651
+ assistant_token: str = Field(
652
+ default="<assistant>",
653
+ description="Special token representing assistant/response role in conversations"
654
+ )
655
+
656
+
657
+ class RLConfig(TrainingConfig):
658
+ """Configuration for reinforcement learning training."""
659
+
660
+ mode: Literal[TrainingModeEnum.RL] = TrainingModeEnum.RL
661
+
662
+ # RL-specific
663
+ rl_method: RLMethodEnum = Field(
664
+ default=RLMethodEnum.PPO,
665
+ description="RL training method (PPO or DPO)"
666
+ )
667
+
668
+ # Reward model
669
+ reward_model_path: str = Field(description="Path to trained reward model checkpoint")
670
+
671
+ # PPO-specific
672
+ ppo_epochs: int = Field(default=4, description="PPO inner epochs")
673
+ ppo_clip_ratio: float = Field(default=0.2, description="PPO clipping ratio")
674
+ entropy_coeff: float = Field(default=0.01, description="Entropy bonus coefficient")
675
+ value_loss_coeff: float = Field(default=1.0, description="Value function loss coefficient")
676
+
677
+ # DPO-specific (Direct Preference Optimization)
678
+ dpo_beta: float = Field(default=0.1, description="DPO inverse temperature (beta)")
679
+
680
+ # Prompt distribution
681
+ prompt_dataset: Optional[DatasetConfig] = Field(
682
+ default=None,
683
+ description="Separate dataset for prompts (if different from main dataset)"
684
+ )
685
+ generation_max_length: int = Field(
686
+ default=256,
687
+ description="Maximum length for generated responses during RL"
688
+ )
689
+
690
+
691
+ # ============================================================================
692
+ # Factory function
693
+ # ============================================================================
694
+
695
+
696
+ def load_config(path: str | Path, mode: TrainingModeEnum | str) -> TrainingConfig:
697
+ """Load config file and return appropriate config class."""
698
+ if isinstance(mode, str):
699
+ mode = TrainingModeEnum(mode)
700
+
701
+ config_map = {
702
+ TrainingModeEnum.PRETRAIN: PretrainConfig,
703
+ TrainingModeEnum.SFT: SFTConfig,
704
+ TrainingModeEnum.RL: RLConfig,
705
+ }
706
+
707
+ config_class = config_map[mode]
708
+
709
+ path = Path(path)
710
+ if path.suffix == '.yaml' or path.suffix == '.yml':
711
+ return config_class.load_yaml(path)
712
+ elif path.suffix == '.json':
713
+ return config_class.load_json(path)
714
+ else:
715
+ raise ValueError(f"Unsupported config file format: {path.suffix}")
716
+
717
+
718
+ def load_tokenizer_config(path: str | Path) -> TokenizerConfig:
719
+ """Load tokenizer config from YAML or JSON file."""
720
+ path = Path(path)
721
+ if path.suffix == '.yaml' or path.suffix == '.yml':
722
+ return TokenizerConfig.load_yaml(path)
723
+ elif path.suffix == '.json':
724
+ return TokenizerConfig.load_json(path)
725
+ else:
726
+ raise ValueError(f"Unsupported config file format: {path.suffix}")