Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)# Load model directly from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "TaoTern/TaoNet-mini-T2" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "TaoTern/TaoNet-mini-T2" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "TaoTern/TaoNet-mini-T2" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +129 -95
- artifacts/configs/pretrain.yaml +86 -0
- artifacts/configs/sft.yaml +92 -0
- artifacts/diagnostics/activation_probe_pretrain_final.json +118 -0
- artifacts/diagnostics/generation_samples_pretrain_final.json +23 -0
- artifacts/diagnostics/generation_samples_sft_final.json +23 -0
- artifacts/gpu_telemetry_nvidia_smi.csv +0 -0
- artifacts/loss_curve_summary.json +26 -0
- artifacts/model_card.json +13 -0
- artifacts/run.sh +55 -0
- artifacts/run_plan.json +19 -0
- artifacts/status.json +1 -0
- code/TaoTrain/.gitignore +150 -0
- code/TaoTrain/README.md +423 -0
- code/TaoTrain/configs/pretrain.yaml +138 -0
- code/TaoTrain/configs/rl_dpo.yaml +60 -0
- code/TaoTrain/configs/sft.yaml +93 -0
- code/TaoTrain/configs/ssm_pretrain.yaml +87 -0
- code/TaoTrain/configs/tokenizer.yaml +44 -0
- code/TaoTrain/configs/tokenizer_taodata_pilot.yaml +22 -0
- code/TaoTrain/configs/yarn_pretrain.yaml +181 -0
- code/TaoTrain/data/sample_data.jsonl +10 -0
- code/TaoTrain/data/sample_sft_multi.jsonl +4 -0
- code/TaoTrain/data/sample_sft_single.jsonl +5 -0
- code/TaoTrain/pyproject.toml +65 -0
- code/TaoTrain/scripts/benchmark_taonet_real_tokens.py +1034 -0
- code/TaoTrain/scripts/benchmark_taonet_token_variants.py +516 -0
- code/TaoTrain/scripts/diagnostics/activation_probe.py +160 -0
- code/TaoTrain/scripts/diagnostics/generate_checkpoint_samples.py +134 -0
- code/TaoTrain/scripts/diagnostics/sft_sanity_check.py +300 -0
- code/TaoTrain/scripts/profile_taonet_components.py +393 -0
- code/TaoTrain/scripts/remote/job_status.sh +36 -0
- code/TaoTrain/scripts/remote/run_200m_base_suite.sh +149 -0
- code/TaoTrain/scripts/remote/run_200m_branch_only_chat.sh +329 -0
- code/TaoTrain/scripts/remote/run_200m_until_selection.sh +81 -0
- code/TaoTrain/scripts/remote/run_pre_200m_branch_only_gate.sh +269 -0
- code/TaoTrain/scripts/remote/run_pre_200m_stability_gate.sh +267 -0
- code/TaoTrain/scripts/remote/submit_200m_base_suite.sh +16 -0
- code/TaoTrain/scripts/remote/submit_200m_branch_only_chat.sh +32 -0
- code/TaoTrain/scripts/remote/submit_200m_until_selection.sh +28 -0
- code/TaoTrain/scripts/remote/submit_detached_job.sh +118 -0
- code/TaoTrain/scripts/remote/submit_pre_200m_branch_only_gate.sh +33 -0
- code/TaoTrain/scripts/remote/submit_pre_200m_stability_gate.sh +32 -0
- code/TaoTrain/scripts/remote/submit_ssm_improvement_sweep.sh +71 -0
- code/TaoTrain/scripts/summarize_taonet_benchmark_suite.py +138 -0
- code/TaoTrain/src/taoTrain.egg-info/dependency_links.txt +1 -0
- code/TaoTrain/src/taoTrain.egg-info/entry_points.txt +4 -0
- code/TaoTrain/src/taoTrain/__init__.py +8 -0
- code/TaoTrain/src/taoTrain/cli.py +387 -0
- code/TaoTrain/src/taoTrain/config.py +726 -0
README.md
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
license: mit
|
| 3 |
language:
|
| 4 |
- en
|
| 5 |
-
library_name:
|
| 6 |
pipeline_tag: text-generation
|
| 7 |
tags:
|
| 8 |
- taonet
|
|
@@ -11,7 +11,8 @@ tags:
|
|
| 11 |
- state-space-model
|
| 12 |
- dplr
|
| 13 |
- pytorch
|
| 14 |
-
-
|
|
|
|
| 15 |
- text-generation
|
| 16 |
- experimental
|
| 17 |
datasets:
|
|
@@ -20,15 +21,93 @@ datasets:
|
|
| 20 |
|
| 21 |
# TaoNet-mini-T2
|
| 22 |
|
| 23 |
-
TaoNet-mini-T2 is an experimental 196M-parameter TaoNet
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
## Model Details
|
| 28 |
|
| 29 |
| Field | Value |
|
| 30 |
-
|---|---|
|
| 31 |
-
| Model family | TaoNet / Taotern SSM |
|
| 32 |
| Architecture | `taonet_ssm` |
|
| 33 |
| Candidate | `pure_ssm_196m_branch_rms_only` |
|
| 34 |
| Parameters | 196,573,128 |
|
|
@@ -45,20 +124,17 @@ This release is intended for research backup, reproducibility, and deployment in
|
|
| 45 |
| Local shift | Enabled, per-channel |
|
| 46 |
| Branch RMS norm | Enabled |
|
| 47 |
|
| 48 |
-
##
|
| 49 |
-
|
| 50 |
-
Download and unzip:
|
| 51 |
-
|
| 52 |
-
```text
|
| 53 |
-
taotern-200m-branch-only-chat-20260514_handoff.zip
|
| 54 |
-
```
|
| 55 |
-
|
| 56 |
-
The zip contains:
|
| 57 |
|
| 58 |
```text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
model/
|
| 60 |
final_model.pt # SFT final checkpoint
|
| 61 |
-
pretrain_final_model.pt #
|
| 62 |
tokenizer/
|
| 63 |
tokenizer.model
|
| 64 |
tokenizer.vocab
|
|
@@ -69,100 +145,64 @@ code/
|
|
| 69 |
artifacts/
|
| 70 |
configs/
|
| 71 |
diagnostics/
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
loss_curve_summary.json
|
| 75 |
-
chat_ssm_fixed.py
|
| 76 |
-
RUN_TEST_WINDOWS.cmd
|
| 77 |
-
setup_windows.ps1
|
| 78 |
-
setup_linux.sh
|
| 79 |
-
run_chat_fixed.ps1
|
| 80 |
-
run_chat_fixed.sh
|
| 81 |
-
CHECKSUMS.sha256
|
| 82 |
-
README_HANDOFF.md
|
| 83 |
```
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
```text
|
| 88 |
-
sha256=5dd45e5d236f37101d183cadb92c5a13ea0aa574a24f578faa6cc8ee7ed166d6
|
| 89 |
-
```
|
| 90 |
-
|
| 91 |
-
## Quick Start
|
| 92 |
|
| 93 |
-
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
```text
|
| 99 |
-
RUN_TEST_WINDOWS.cmd
|
| 100 |
```
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
Manual Windows setup:
|
| 105 |
|
| 106 |
```powershell
|
| 107 |
-
powershell -ExecutionPolicy Bypass -File .\
|
| 108 |
-
powershell -ExecutionPolicy Bypass -File .\run_chat_fixed.ps1
|
| 109 |
```
|
| 110 |
|
| 111 |
-
|
| 112 |
|
| 113 |
-
```
|
| 114 |
-
powershell -ExecutionPolicy Bypass -File .\setup_windows.ps1 -TorchFlavor cu126
|
| 115 |
-
```
|
| 116 |
|
| 117 |
-
|
| 118 |
|
| 119 |
-
```
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
```
|
| 122 |
|
| 123 |
-
|
| 124 |
|
| 125 |
-
|
| 126 |
-
unzip taotern-200m-branch-only-chat-20260514_handoff.zip
|
| 127 |
-
cd taotern-200m-branch-only-chat-20260514
|
| 128 |
-
chmod +x setup_linux.sh run_chat_fixed.sh
|
| 129 |
-
./setup_linux.sh cu128
|
| 130 |
-
./run_chat_fixed.sh
|
| 131 |
-
```
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
```bash
|
| 136 |
-
./setup_linux.sh cu126
|
| 137 |
-
```
|
| 138 |
-
|
| 139 |
-
## Recommended Inference Mode
|
| 140 |
-
|
| 141 |
-
The checkpoint was trained with convolutional SSM execution. For chat testing, this bundle uses a temporary RepoBridge-style inference fix:
|
| 142 |
|
| 143 |
```text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
ssm_finite_tail_correction=true
|
| 145 |
-
|
| 146 |
-
default_checkpoint=model/pretrain_final_model.pt
|
| 147 |
-
decode=greedy
|
| 148 |
-
top_p=0.85
|
| 149 |
-
repetition_penalty=1.2
|
| 150 |
-
max_new_tokens=64
|
| 151 |
-
```
|
| 152 |
-
|
| 153 |
-
Run it directly:
|
| 154 |
-
|
| 155 |
-
```bash
|
| 156 |
-
python chat_ssm_fixed.py --interactive
|
| 157 |
```
|
| 158 |
|
| 159 |
-
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
-
|
| 166 |
|
| 167 |
## Training Summary
|
| 168 |
|
|
@@ -172,8 +212,6 @@ Run ID:
|
|
| 172 |
taotern-200m-branch-only-chat-20260514
|
| 173 |
```
|
| 174 |
|
| 175 |
-
Training shape:
|
| 176 |
-
|
| 177 |
| Stage | Value |
|
| 178 |
|---|---:|
|
| 179 |
| Pretrain token positions | 4,000,000,000 |
|
|
@@ -204,17 +242,14 @@ This model is intended for:
|
|
| 204 |
|
| 205 |
- Taotern/TaoNet SSM research
|
| 206 |
- checkpoint backup and reproducibility
|
| 207 |
-
- deployment experiments
|
| 208 |
- studying recurrent SSM inference behavior
|
| 209 |
|
| 210 |
-
It is not currently a drop-in Transformers model and is not intended as a polished production chatbot.
|
| 211 |
-
|
| 212 |
## Limitations
|
| 213 |
|
| 214 |
-
- Experimental model quality; validate before use.
|
| 215 |
-
- Requires
|
| 216 |
- The recommended chat path depends on an inference-time SSM override.
|
| 217 |
-
- Not currently packaged as `AutoModelForCausalLM`.
|
| 218 |
- CPU inference is expected to be very slow.
|
| 219 |
- English-focused pilot data/tokenizer.
|
| 220 |
|
|
@@ -232,4 +267,3 @@ It is not currently a drop-in Transformers model and is not intended as a polish
|
|
| 232 |
## Related
|
| 233 |
|
| 234 |
- [TaoTern/TaoNet-pico-T1](https://huggingface.co/TaoTern/TaoNet-pico-T1)
|
| 235 |
-
|
|
|
|
| 2 |
license: mit
|
| 3 |
language:
|
| 4 |
- en
|
| 5 |
+
library_name: transformers
|
| 6 |
pipeline_tag: text-generation
|
| 7 |
tags:
|
| 8 |
- taonet
|
|
|
|
| 11 |
- state-space-model
|
| 12 |
- dplr
|
| 13 |
- pytorch
|
| 14 |
+
- transformers
|
| 15 |
+
- custom_code
|
| 16 |
- text-generation
|
| 17 |
- experimental
|
| 18 |
datasets:
|
|
|
|
| 21 |
|
| 22 |
# TaoNet-mini-T2
|
| 23 |
|
| 24 |
+
TaoNet-mini-T2 is an experimental 196M-parameter TaoNet language model using a Taotern/Gamma DPLR state-space model (SSM) sequence core instead of attention. The repository includes the full training handoff package, but the recommended inference path is now Hugging Face `transformers` remote code:
|
| 25 |
|
| 26 |
+
```python
|
| 27 |
+
AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True)
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
The default `transformers` loader downloads `model/pretrain_final_model.pt` and applies the RepoBridge chat-quality fix: `ssm_finite_tail_correction=True` and `ssm_kernel_mode="recurrent"`.
|
| 31 |
+
|
| 32 |
+
## Quick Start
|
| 33 |
+
|
| 34 |
+
Install runtime dependencies:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
pip install torch transformers sentencepiece huggingface_hub pydantic pydantic-settings pyyaml numpy
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
For the private review repo, log in first:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
hf auth login
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Run generation from Python:
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
import time
|
| 50 |
+
import torch
|
| 51 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 52 |
+
|
| 53 |
+
MODEL_NAME = "TaoTern/TaoNet-mini-T2"
|
| 54 |
+
|
| 55 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 56 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 57 |
+
|
| 58 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
| 59 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 60 |
+
MODEL_NAME,
|
| 61 |
+
trust_remote_code=True,
|
| 62 |
+
torch_dtype=dtype,
|
| 63 |
+
).to(device)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def generate_text(prompt, max_new_tokens=64, temperature=0.7, top_p=0.85):
|
| 67 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 68 |
+
inputs = {key: value.to(device) for key, value in inputs.items()}
|
| 69 |
+
|
| 70 |
+
start_time = time.time()
|
| 71 |
+
with torch.inference_mode():
|
| 72 |
+
outputs = model.generate(
|
| 73 |
+
**inputs,
|
| 74 |
+
max_new_tokens=max_new_tokens,
|
| 75 |
+
temperature=temperature,
|
| 76 |
+
top_p=top_p,
|
| 77 |
+
repetition_penalty=1.2,
|
| 78 |
+
do_sample=True,
|
| 79 |
+
use_cache=False,
|
| 80 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 81 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 82 |
+
)
|
| 83 |
+
elapsed_time = time.time() - start_time
|
| 84 |
+
|
| 85 |
+
new_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
|
| 86 |
+
tokens_per_second = new_tokens / elapsed_time if elapsed_time > 0 else 0.0
|
| 87 |
+
completion = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 88 |
+
return completion, tokens_per_second
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
text, tps = generate_text("Fruit is now expensive so we should")
|
| 93 |
+
print(text)
|
| 94 |
+
print(f"\nTokens per second: {tps:.2f}")
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
To load the SFT final checkpoint instead of the default pretrain checkpoint:
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 101 |
+
"TaoTern/TaoNet-mini-T2",
|
| 102 |
+
trust_remote_code=True,
|
| 103 |
+
checkpoint_name="final_model.pt",
|
| 104 |
+
)
|
| 105 |
+
```
|
| 106 |
|
| 107 |
## Model Details
|
| 108 |
|
| 109 |
| Field | Value |
|
| 110 |
+
|---|---:|
|
|
|
|
| 111 |
| Architecture | `taonet_ssm` |
|
| 112 |
| Candidate | `pure_ssm_196m_branch_rms_only` |
|
| 113 |
| Parameters | 196,573,128 |
|
|
|
|
| 124 |
| Local shift | Enabled, per-channel |
|
| 125 |
| Branch RMS norm | Enabled |
|
| 126 |
|
| 127 |
+
## Repository Layout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
```text
|
| 130 |
+
config.json
|
| 131 |
+
configuration_taonet_mini_t2.py
|
| 132 |
+
modeling_taonet_mini_t2.py
|
| 133 |
+
tokenization_taonet_mini_t2.py
|
| 134 |
+
tokenizer.model
|
| 135 |
model/
|
| 136 |
final_model.pt # SFT final checkpoint
|
| 137 |
+
pretrain_final_model.pt # default checkpoint for HF inference
|
| 138 |
tokenizer/
|
| 139 |
tokenizer.model
|
| 140 |
tokenizer.vocab
|
|
|
|
| 145 |
artifacts/
|
| 146 |
configs/
|
| 147 |
diagnostics/
|
| 148 |
+
chat_ssm_fixed.py # legacy local fixed-chat CLI
|
| 149 |
+
eval_lm_eval.py # local lm-eval harness wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
```
|
| 151 |
|
| 152 |
+
## Upload Notes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
This repo contains two multi-GB checkpoint files, so prefer the resumable large-folder uploader instead of the normal single-commit upload command:
|
| 155 |
|
| 156 |
+
```bash
|
| 157 |
+
hf upload-large-folder TaoTern/TaoNet-mini-T2 . --repo-type model --private
|
|
|
|
|
|
|
|
|
|
| 158 |
```
|
| 159 |
|
| 160 |
+
On Windows, from the repo folder:
|
|
|
|
|
|
|
| 161 |
|
| 162 |
```powershell
|
| 163 |
+
powershell -ExecutionPolicy Bypass -File .\upload_large_folder.ps1
|
|
|
|
| 164 |
```
|
| 165 |
|
| 166 |
+
## Inference Notes
|
| 167 |
|
| 168 |
+
The training config used `ssm_finite_tail_correction=False` and `ssm_kernel_mode="conv"`. That path is fast for full-sequence training/evaluation but produced poor chat samples in the recovered workflow.
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
The `transformers` wrapper defaults to:
|
| 171 |
|
| 172 |
+
```text
|
| 173 |
+
ssm_finite_tail_correction=True
|
| 174 |
+
ssm_kernel_mode=recurrent
|
| 175 |
+
checkpoint=model/pretrain_final_model.pt
|
| 176 |
```
|
| 177 |
|
| 178 |
+
For fast benchmark scoring, use the included `eval_lm_eval.py` script with `--ssm-kernel-mode conv --finite-tail`.
|
| 179 |
|
| 180 |
+
## LM Evaluation Harness Benchmark
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
+
Settings:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
```text
|
| 185 |
+
library=lm-eval-harness
|
| 186 |
+
checkpoint=model/pretrain_final_model.pt
|
| 187 |
+
num_fewshot=0
|
| 188 |
+
limit=100
|
| 189 |
+
ssm_kernel_mode=conv
|
| 190 |
ssm_finite_tail_correction=true
|
| 191 |
+
eval_batch_size=8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
```
|
| 193 |
|
| 194 |
+
Results:
|
| 195 |
|
| 196 |
+
| Task | Primary score |
|
| 197 |
+
|---|---:|
|
| 198 |
+
| HellaSwag | 0.3300 |
|
| 199 |
+
| ARC Easy | 0.3400 |
|
| 200 |
+
| ARC Challenge | 0.2200 |
|
| 201 |
+
| PIQA | 0.4400 |
|
| 202 |
+
| Winogrande | 0.5300 |
|
| 203 |
+
| Mean primary score | 0.3720 |
|
| 204 |
|
| 205 |
+
These are limit-100 smoke benchmark numbers for review, not full leaderboard results.
|
| 206 |
|
| 207 |
## Training Summary
|
| 208 |
|
|
|
|
| 212 |
taotern-200m-branch-only-chat-20260514
|
| 213 |
```
|
| 214 |
|
|
|
|
|
|
|
| 215 |
| Stage | Value |
|
| 216 |
|---|---:|
|
| 217 |
| Pretrain token positions | 4,000,000,000 |
|
|
|
|
| 242 |
|
| 243 |
- Taotern/TaoNet SSM research
|
| 244 |
- checkpoint backup and reproducibility
|
| 245 |
+
- deployment experiments with custom Hugging Face remote code
|
| 246 |
- studying recurrent SSM inference behavior
|
| 247 |
|
|
|
|
|
|
|
| 248 |
## Limitations
|
| 249 |
|
| 250 |
+
- Experimental model quality; validate outputs before use.
|
| 251 |
+
- Requires `trust_remote_code=True` because the architecture is not part of upstream `transformers`.
|
| 252 |
- The recommended chat path depends on an inference-time SSM override.
|
|
|
|
| 253 |
- CPU inference is expected to be very slow.
|
| 254 |
- English-focused pilot data/tokenizer.
|
| 255 |
|
|
|
|
| 267 |
## Related
|
| 268 |
|
| 269 |
- [TaoTern/TaoNet-pico-T1](https://huggingface.co/TaoTern/TaoNet-pico-T1)
|
|
|
artifacts/configs/pretrain.yaml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
architecture_type: taonet_ssm
|
| 3 |
+
vocab_size: 8192
|
| 4 |
+
hidden_dim: 1024
|
| 5 |
+
num_layers: 18
|
| 6 |
+
num_heads: 8
|
| 7 |
+
max_seq_length: 512
|
| 8 |
+
d_latent_kv: 768
|
| 9 |
+
d_rope: 128
|
| 10 |
+
hidden_dim_ff: 3072
|
| 11 |
+
dropout: 0.0
|
| 12 |
+
gqa_groups: 1
|
| 13 |
+
use_factorized_embedding: false
|
| 14 |
+
d_embed_rank: 96
|
| 15 |
+
init_std: 0.02
|
| 16 |
+
ssm_core: dplr
|
| 17 |
+
ssm_hidden_dim: 32
|
| 18 |
+
ssm_mixer_dim: 256
|
| 19 |
+
ssm_num_lanes: 2
|
| 20 |
+
ssm_lane_combine: channel
|
| 21 |
+
ssm_lane_mode: split
|
| 22 |
+
ssm_split_mix: none
|
| 23 |
+
ssm_rank: 1
|
| 24 |
+
ssm_max_low_rank_scale: 0.1
|
| 25 |
+
ssm_finite_tail_correction: false
|
| 26 |
+
ssm_discretization: bilinear
|
| 27 |
+
ssm_kernel_mode: conv
|
| 28 |
+
ssm_kernel_threshold: 64
|
| 29 |
+
ssm_dt_min: 1e-3
|
| 30 |
+
ssm_dt_max: 1e-1
|
| 31 |
+
ssm_dt_init: 1e-2
|
| 32 |
+
ssm_use_d: true
|
| 33 |
+
ssm_activation: gelu
|
| 34 |
+
ssm_gate: true
|
| 35 |
+
ssm_input_gate: true
|
| 36 |
+
ssm_gate_type: channel
|
| 37 |
+
ssm_use_padding_mask: false
|
| 38 |
+
ssm_layer_scale_init: 0.1
|
| 39 |
+
ssm_branch_rms_norm: true
|
| 40 |
+
block_residual_rms_norm: false
|
| 41 |
+
|
| 42 |
+
ssm_local_shift: true
|
| 43 |
+
ssm_local_shift_init: 0.1
|
| 44 |
+
ssm_local_shift_per_channel: true
|
| 45 |
+
|
| 46 |
+
dataset:
|
| 47 |
+
local: true
|
| 48 |
+
jsonl_path: /home/student/Data/TaoData/pretrain.jsonl
|
| 49 |
+
text_field: text
|
| 50 |
+
tokenizer_type: sentencepiece
|
| 51 |
+
tokenizer_path: /home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model
|
| 52 |
+
samples_per_chunk: 2000
|
| 53 |
+
tokenizer_threads: 8
|
| 54 |
+
|
| 55 |
+
sequence_length: 512
|
| 56 |
+
batch_size: 8
|
| 57 |
+
num_epochs: 100000
|
| 58 |
+
max_steps: 976563
|
| 59 |
+
gradient_accumulation_steps: 1
|
| 60 |
+
max_grad_norm: 1.0
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
optimizer_type: adamw
|
| 64 |
+
learning_rate: 0.0008
|
| 65 |
+
weight_decay: 0.01
|
| 66 |
+
betas: [0.9, 0.999]
|
| 67 |
+
eps: 1e-8
|
| 68 |
+
|
| 69 |
+
scheduler:
|
| 70 |
+
scheduler_type: constant
|
| 71 |
+
warmup_steps: 1000
|
| 72 |
+
warmup_ratio: 0.0
|
| 73 |
+
|
| 74 |
+
dtype: bfloat16
|
| 75 |
+
device: cuda
|
| 76 |
+
checkpoint_dir: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain
|
| 77 |
+
save_every_steps: 100000
|
| 78 |
+
save_best_model: false
|
| 79 |
+
keep_last_n_checkpoints: 3
|
| 80 |
+
eval_every_steps: 100000
|
| 81 |
+
eval_samples: 32
|
| 82 |
+
log_every_steps: 100
|
| 83 |
+
aim_repo: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/.aim-pretrain
|
| 84 |
+
seed: 43
|
| 85 |
+
num_workers: 0
|
| 86 |
+
pin_memory: true
|
artifacts/configs/sft.yaml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
architecture_type: taonet_ssm
|
| 3 |
+
vocab_size: 8192
|
| 4 |
+
hidden_dim: 1024
|
| 5 |
+
num_layers: 18
|
| 6 |
+
num_heads: 8
|
| 7 |
+
max_seq_length: 512
|
| 8 |
+
d_latent_kv: 768
|
| 9 |
+
d_rope: 128
|
| 10 |
+
hidden_dim_ff: 3072
|
| 11 |
+
dropout: 0.0
|
| 12 |
+
gqa_groups: 1
|
| 13 |
+
use_factorized_embedding: false
|
| 14 |
+
d_embed_rank: 96
|
| 15 |
+
init_std: 0.02
|
| 16 |
+
ssm_core: dplr
|
| 17 |
+
ssm_hidden_dim: 32
|
| 18 |
+
ssm_mixer_dim: 256
|
| 19 |
+
ssm_num_lanes: 2
|
| 20 |
+
ssm_lane_combine: channel
|
| 21 |
+
ssm_lane_mode: split
|
| 22 |
+
ssm_split_mix: none
|
| 23 |
+
ssm_rank: 1
|
| 24 |
+
ssm_max_low_rank_scale: 0.1
|
| 25 |
+
ssm_finite_tail_correction: false
|
| 26 |
+
ssm_discretization: bilinear
|
| 27 |
+
ssm_kernel_mode: conv
|
| 28 |
+
ssm_kernel_threshold: 64
|
| 29 |
+
ssm_dt_min: 1e-3
|
| 30 |
+
ssm_dt_max: 1e-1
|
| 31 |
+
ssm_dt_init: 1e-2
|
| 32 |
+
ssm_use_d: true
|
| 33 |
+
ssm_activation: gelu
|
| 34 |
+
ssm_gate: true
|
| 35 |
+
ssm_input_gate: true
|
| 36 |
+
ssm_gate_type: channel
|
| 37 |
+
ssm_use_padding_mask: false
|
| 38 |
+
ssm_layer_scale_init: 0.1
|
| 39 |
+
ssm_branch_rms_norm: true
|
| 40 |
+
block_residual_rms_norm: false
|
| 41 |
+
|
| 42 |
+
ssm_local_shift: true
|
| 43 |
+
ssm_local_shift_init: 0.1
|
| 44 |
+
ssm_local_shift_per_channel: true
|
| 45 |
+
|
| 46 |
+
dataset:
|
| 47 |
+
split: train
|
| 48 |
+
instruction_column: input
|
| 49 |
+
response_column: output
|
| 50 |
+
local: true
|
| 51 |
+
jsonl_path: /home/student/Data/TaoData/sft.jsonl
|
| 52 |
+
samples_per_chunk: 2000
|
| 53 |
+
tokenizer_type: sentencepiece
|
| 54 |
+
tokenizer_path: /home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model
|
| 55 |
+
tokenizer_threads: 8
|
| 56 |
+
|
| 57 |
+
checkpoint_path: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt
|
| 58 |
+
user_token: "<user>"
|
| 59 |
+
assistant_token: "<assistant>"
|
| 60 |
+
response_loss_only: true
|
| 61 |
+
|
| 62 |
+
batch_size: 8
|
| 63 |
+
num_epochs: 100000
|
| 64 |
+
max_steps: 50000
|
| 65 |
+
gradient_accumulation_steps: 1
|
| 66 |
+
max_grad_norm: 1.0
|
| 67 |
+
|
| 68 |
+
optimizer:
|
| 69 |
+
optimizer_type: adamw
|
| 70 |
+
learning_rate: 0.00005
|
| 71 |
+
weight_decay: 0.0
|
| 72 |
+
betas: [0.9, 0.999]
|
| 73 |
+
eps: 1e-8
|
| 74 |
+
|
| 75 |
+
scheduler:
|
| 76 |
+
scheduler_type: constant
|
| 77 |
+
warmup_steps: 100
|
| 78 |
+
warmup_ratio: 0.0
|
| 79 |
+
|
| 80 |
+
dtype: bfloat16
|
| 81 |
+
device: cuda
|
| 82 |
+
checkpoint_dir: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/sft
|
| 83 |
+
save_every_steps: 10000
|
| 84 |
+
save_best_model: false
|
| 85 |
+
keep_last_n_checkpoints: 3
|
| 86 |
+
eval_every_steps: 10000
|
| 87 |
+
eval_samples: 32
|
| 88 |
+
log_every_steps: 20
|
| 89 |
+
aim_repo: /home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/.aim-sft
|
| 90 |
+
seed: 44
|
| 91 |
+
num_workers: 0
|
| 92 |
+
pin_memory: true
|
artifacts/diagnostics/activation_probe_pretrain_final.json
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt",
|
| 3 |
+
"loss": 2.8459982872009277,
|
| 4 |
+
"batch_size": 2,
|
| 5 |
+
"seq_len": 512,
|
| 6 |
+
"device": "cuda",
|
| 7 |
+
"dtype": "torch.bfloat16",
|
| 8 |
+
"layers": {
|
| 9 |
+
"blocks.0": {
|
| 10 |
+
"numel": 1048576,
|
| 11 |
+
"finite": 1048576,
|
| 12 |
+
"rms": 2.9217934608459473,
|
| 13 |
+
"max_abs": 62.90830612182617
|
| 14 |
+
},
|
| 15 |
+
"blocks.1": {
|
| 16 |
+
"numel": 1048576,
|
| 17 |
+
"finite": 1048576,
|
| 18 |
+
"rms": 3.4865853786468506,
|
| 19 |
+
"max_abs": 63.050193786621094
|
| 20 |
+
},
|
| 21 |
+
"blocks.2": {
|
| 22 |
+
"numel": 1048576,
|
| 23 |
+
"finite": 1048576,
|
| 24 |
+
"rms": 4.092358589172363,
|
| 25 |
+
"max_abs": 98.85921478271484
|
| 26 |
+
},
|
| 27 |
+
"blocks.3": {
|
| 28 |
+
"numel": 1048576,
|
| 29 |
+
"finite": 1048576,
|
| 30 |
+
"rms": 4.814863681793213,
|
| 31 |
+
"max_abs": 160.82785034179688
|
| 32 |
+
},
|
| 33 |
+
"blocks.4": {
|
| 34 |
+
"numel": 1048576,
|
| 35 |
+
"finite": 1048576,
|
| 36 |
+
"rms": 6.032325267791748,
|
| 37 |
+
"max_abs": 237.8623504638672
|
| 38 |
+
},
|
| 39 |
+
"blocks.5": {
|
| 40 |
+
"numel": 1048576,
|
| 41 |
+
"finite": 1048576,
|
| 42 |
+
"rms": 7.5034403800964355,
|
| 43 |
+
"max_abs": 324.9512939453125
|
| 44 |
+
},
|
| 45 |
+
"blocks.6": {
|
| 46 |
+
"numel": 1048576,
|
| 47 |
+
"finite": 1048576,
|
| 48 |
+
"rms": 8.303068161010742,
|
| 49 |
+
"max_abs": 345.4449462890625
|
| 50 |
+
},
|
| 51 |
+
"blocks.7": {
|
| 52 |
+
"numel": 1048576,
|
| 53 |
+
"finite": 1048576,
|
| 54 |
+
"rms": 8.911171913146973,
|
| 55 |
+
"max_abs": 337.6418151855469
|
| 56 |
+
},
|
| 57 |
+
"blocks.8": {
|
| 58 |
+
"numel": 1048576,
|
| 59 |
+
"finite": 1048576,
|
| 60 |
+
"rms": 10.293173789978027,
|
| 61 |
+
"max_abs": 343.7690124511719
|
| 62 |
+
},
|
| 63 |
+
"blocks.9": {
|
| 64 |
+
"numel": 1048576,
|
| 65 |
+
"finite": 1048576,
|
| 66 |
+
"rms": 11.591143608093262,
|
| 67 |
+
"max_abs": 361.6515808105469
|
| 68 |
+
},
|
| 69 |
+
"blocks.10": {
|
| 70 |
+
"numel": 1048576,
|
| 71 |
+
"finite": 1048576,
|
| 72 |
+
"rms": 13.886137962341309,
|
| 73 |
+
"max_abs": 401.029052734375
|
| 74 |
+
},
|
| 75 |
+
"blocks.11": {
|
| 76 |
+
"numel": 1048576,
|
| 77 |
+
"finite": 1048576,
|
| 78 |
+
"rms": 15.372880935668945,
|
| 79 |
+
"max_abs": 406.81036376953125
|
| 80 |
+
},
|
| 81 |
+
"blocks.12": {
|
| 82 |
+
"numel": 1048576,
|
| 83 |
+
"finite": 1048576,
|
| 84 |
+
"rms": 18.466873168945312,
|
| 85 |
+
"max_abs": 541.068603515625
|
| 86 |
+
},
|
| 87 |
+
"blocks.13": {
|
| 88 |
+
"numel": 1048576,
|
| 89 |
+
"finite": 1048576,
|
| 90 |
+
"rms": 20.376676559448242,
|
| 91 |
+
"max_abs": 553.8697509765625
|
| 92 |
+
},
|
| 93 |
+
"blocks.14": {
|
| 94 |
+
"numel": 1048576,
|
| 95 |
+
"finite": 1048576,
|
| 96 |
+
"rms": 23.877737045288086,
|
| 97 |
+
"max_abs": 652.4536743164062
|
| 98 |
+
},
|
| 99 |
+
"blocks.15": {
|
| 100 |
+
"numel": 1048576,
|
| 101 |
+
"finite": 1048576,
|
| 102 |
+
"rms": 28.56619644165039,
|
| 103 |
+
"max_abs": 874.2588500976562
|
| 104 |
+
},
|
| 105 |
+
"blocks.16": {
|
| 106 |
+
"numel": 1048576,
|
| 107 |
+
"finite": 1048576,
|
| 108 |
+
"rms": 35.949851989746094,
|
| 109 |
+
"max_abs": 1143.6116943359375
|
| 110 |
+
},
|
| 111 |
+
"blocks.17": {
|
| 112 |
+
"numel": 1048576,
|
| 113 |
+
"finite": 1048576,
|
| 114 |
+
"rms": 45.9725227355957,
|
| 115 |
+
"max_abs": 2560.034423828125
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
}
|
artifacts/diagnostics/generation_samples_pretrain_final.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt",
|
| 3 |
+
"tokenizer_path": "/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model",
|
| 4 |
+
"device": "cuda",
|
| 5 |
+
"dtype": "torch.bfloat16",
|
| 6 |
+
"max_new_tokens": 160,
|
| 7 |
+
"temperature": 0.8,
|
| 8 |
+
"top_p": 0.9,
|
| 9 |
+
"samples": [
|
| 10 |
+
{
|
| 11 |
+
"prompt": "The purpose of artificial intelligence is",
|
| 12 |
+
"completion": "to enable human creativity to be a kind of artifact of politics - and sometimes to do that the principle of human dignity cannot be denied without the distinction of two; but the only thing that gives those who rely on it the ability to be free to decide the right thing.In some cases, the majority of the rules about how to use such knowledge are not clearly understood by the law. This leads to the generation of the same arguments, which leads to the mistake of the expert who sought to apply it to everything else.Well, to comply with the human dignity. There is some good news. There are some types of discrimination. This is because it is always a death sentence. Many people are not even able to get to court."
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"prompt": "In a small village,",
|
| 16 |
+
"completion": "you can live alone, but sometimes you might have to leave the house without a roof. Most of these city residents get two hours of sleep on most of the time, while many do.When it comes to school:Times help.Most of the time, wind is always the most important (or the least important): being well rested during the day.When the weather gets hot, the symptoms of insomnia can appear. Most of the time it just doesn't. Most people get enough sleep.Typically, we don't need to get more than 2 hours.Why are there such a large number of people getting the same kind of sleep?There are some factors that make it difficult to identify as having sleep problems. Some people do get a few"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"prompt": "<user>Hello, who are you?<assistant>",
|
| 20 |
+
"completion": "24.2 (2) 1:10Though many households still live without electricity, most people who live close to the power lines will do so to the extent that they have only a small amount of electricity at the end of the day (33%).(3) The average household electricity rate is higher for the two types of households at 32%, with 31% of the 2005-06 peaks.7Take home message:- Unrealistically high societies. This is one factor that makes getting it to such a huge difference. This means that one has to make some very poor choices when it comes to one's overall lifestyle. Many cities do not have a proper gas taxes;"
|
| 21 |
+
}
|
| 22 |
+
]
|
| 23 |
+
}
|
artifacts/diagnostics/generation_samples_sft_final.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/sft/final_model.pt",
|
| 3 |
+
"tokenizer_path": "/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model",
|
| 4 |
+
"device": "cuda",
|
| 5 |
+
"dtype": "torch.bfloat16",
|
| 6 |
+
"max_new_tokens": 220,
|
| 7 |
+
"temperature": 0.7,
|
| 8 |
+
"top_p": 0.9,
|
| 9 |
+
"samples": [
|
| 10 |
+
{
|
| 11 |
+
"prompt": "<user>Hello, who are you?<assistant>",
|
| 12 |
+
"completion": "A tip is usually a net, a pack is a scavenger, which means that the initial effort is the delta, the body, the nav, is the ad. ⁇ ** Also, typically, a person typically becomes anemic and always begins to move around, which typically involves movement, which typically involves some movement, such as movement, which can also affect movement. ⁇ **Astrophobic** is not always a result of a very dramatic event. ⁇ * Often, it).That must be.*That got lost. ⁇ That got you got.)That must have got you confused. Most doctors agree. Muscles seem to lose flexibility.But steroids can still keep you sharp. Muscle cramps. Muscle cramps. Muscle fatigue. Most adults can still breathe. ⁇ **Most organs have multiple functions.** ⁇ Most treatments are aimed at improving symptoms. But advanced treatments may trigger inflammation. Low maintenance but soft but solid retain or soft core. Presence: Completing the major: Co the specific major"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"prompt": "<user>Explain what artificial intelligence is in simple words.<assistant>",
|
| 16 |
+
"completion": "A oscillator is a singular and eventually identical, specialized singular, the silencer. However, sometimes, sometimes, sometimes, the particles are almost always and typically surrounded by unusual or unusual characteristics, such as being very agile, such as being very small, very small, and often very dense, sometimes very really really. ⁇ They can also sometimes also sometimes sometimes sometimes. Sometimes, sometimes sometimes sometimes sometimes. Sometimes, sometimes sometimes sometimes sometimes. ⁇ However,).).....)....That.).G...That.K..K..New..New....New..New......These two devices seem quite close. Each has its own advantage. ⁇ But sometimes. But sometimes. It doesn't. But it still does. But... The hypothesis. The hypothesis. The untested or rare. Comes in red. P). Complex. Co. High. CoinIn the general, the"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"prompt": "<user>Give me three practical study tips.<assistant>",
|
| 20 |
+
"completion": "Dementia, also known as numbness or tingling under the arc (the gait), is a physical deformity. The necessity is the 11111111 and also typically includes an injury or condition that makes it difficult for you to do some simple things, such as getting dressed, getting some rest, or getting some sleep, but not actually getting really tired. Sometimes simply having this sometimes sometimes sometimes does really does sometimes does. Sometimes, sometimes sometimes sometimes. Sometimes sometimes.)It.The....).)...).That...That.. Two. ⁇ .That.. ⁇ Check it. ⁇ Check everything. ⁇ Check it. ⁇ But keep it small. Remember to keep everything hidden. Remember to keep everything organized. Remember to keep everyone informed. Remember to keep valuable items readily accessible. Most businesses don't offer absolute protection. Let the deco stay solid. Cos. Cos. Co. is a specific sub. ⁇ In a Coin,"
|
| 21 |
+
}
|
| 22 |
+
]
|
| 23 |
+
}
|
artifacts/gpu_telemetry_nvidia_smi.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
artifacts/loss_curve_summary.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pretrain": {
|
| 3 |
+
"points": 961817,
|
| 4 |
+
"first_loss": 9.26,
|
| 5 |
+
"last_loss": 2.64,
|
| 6 |
+
"min_loss": 0.577,
|
| 7 |
+
"max_loss": 9.28,
|
| 8 |
+
"tail20_mean": 2.367,
|
| 9 |
+
"tail100_mean": 2.3350999999999997,
|
| 10 |
+
"tail200_mean": 2.3415500000000002,
|
| 11 |
+
"tail1000_mean": 2.373460000000001,
|
| 12 |
+
"csv_path": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/analysis/pretrain_loss_curve.csv"
|
| 13 |
+
},
|
| 14 |
+
"sft": {
|
| 15 |
+
"points": 49238,
|
| 16 |
+
"first_loss": 3.2,
|
| 17 |
+
"last_loss": 1.08,
|
| 18 |
+
"min_loss": 0.386,
|
| 19 |
+
"max_loss": 3.38,
|
| 20 |
+
"tail20_mean": 0.9390999999999998,
|
| 21 |
+
"tail100_mean": 0.9585200000000004,
|
| 22 |
+
"tail200_mean": 0.965355,
|
| 23 |
+
"tail1000_mean": 0.9520289999999981,
|
| 24 |
+
"csv_path": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/analysis/sft_loss_curve.csv"
|
| 25 |
+
}
|
| 26 |
+
}
|
artifacts/model_card.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state": "completed",
|
| 3 |
+
"architecture": "taonet_ssm",
|
| 4 |
+
"candidate": "pure_ssm_196m_branch_rms_only",
|
| 5 |
+
"pretrain_checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/pretrain/final_model.pt",
|
| 6 |
+
"sft_checkpoint": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/checkpoints/sft/final_model.pt",
|
| 7 |
+
"tokenizer_path": "/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model",
|
| 8 |
+
"pretrain_config": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/configs/pretrain.yaml",
|
| 9 |
+
"sft_config": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/configs/sft.yaml",
|
| 10 |
+
"activation_probe": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/diagnostics/activation_probe_pretrain_final.json",
|
| 11 |
+
"pretrain_samples": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/diagnostics/generation_samples_pretrain_final.json",
|
| 12 |
+
"sft_samples": "/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514/outputs/diagnostics/generation_samples_sft_final.json"
|
| 13 |
+
}
|
artifacts/run.sh
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -Eeuo pipefail
|
| 3 |
+
|
| 4 |
+
job_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
status_json="$job_dir/status.json"
|
| 6 |
+
log_path="$job_dir/train.log"
|
| 7 |
+
telemetry_path="$job_dir/gpu_telemetry_nvidia_smi.csv"
|
| 8 |
+
|
| 9 |
+
write_status() {
|
| 10 |
+
local state="$1"
|
| 11 |
+
local extra="${2:-}"
|
| 12 |
+
local now
|
| 13 |
+
now="$(date -Iseconds)"
|
| 14 |
+
cat > "$status_json" <<STATUS
|
| 15 |
+
{"state":"$state","updated_at":"$now","job_dir":"$job_dir"$extra}
|
| 16 |
+
STATUS
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
write_status "running" ",\"start_time\":\"$(date -Iseconds)\""
|
| 20 |
+
rm -f "$job_dir/DONE" "$job_dir/FAILED"
|
| 21 |
+
|
| 22 |
+
monitor_pid=""
|
| 23 |
+
if command -v nvidia-smi >/dev/null 2>&1; then
|
| 24 |
+
(
|
| 25 |
+
while true; do
|
| 26 |
+
date -Iseconds
|
| 27 |
+
nvidia-smi --query-gpu=timestamp,name,utilization.gpu,utilization.memory,memory.used,memory.total,power.draw,temperature.gpu --format=csv,noheader,nounits
|
| 28 |
+
sleep 5
|
| 29 |
+
done
|
| 30 |
+
) > "$telemetry_path" 2>&1 &
|
| 31 |
+
monitor_pid="$!"
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
cleanup() {
|
| 35 |
+
if [[ -n "$monitor_pid" ]]; then
|
| 36 |
+
kill "$monitor_pid" 2>/dev/null || true
|
| 37 |
+
wait "$monitor_pid" 2>/dev/null || true
|
| 38 |
+
fi
|
| 39 |
+
}
|
| 40 |
+
trap cleanup EXIT
|
| 41 |
+
|
| 42 |
+
set +e
|
| 43 |
+
"$job_dir/command.sh" > "$log_path" 2>&1
|
| 44 |
+
exit_code="$?"
|
| 45 |
+
set -e
|
| 46 |
+
|
| 47 |
+
if [[ "$exit_code" -eq 0 ]]; then
|
| 48 |
+
touch "$job_dir/DONE"
|
| 49 |
+
write_status "completed" ",\"exit_code\":0,\"end_time\":\"$(date -Iseconds)\""
|
| 50 |
+
else
|
| 51 |
+
echo "$exit_code" > "$job_dir/FAILED"
|
| 52 |
+
write_status "failed" ",\"exit_code\":$exit_code,\"end_time\":\"$(date -Iseconds)\""
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
exit "$exit_code"
|
artifacts/run_plan.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"purpose": "200m_branch_only_pure_ssm_4b_pretrain_plus_sft_chatbot_attempt",
|
| 3 |
+
"candidate": "pure_ssm_196m_branch_rms_only",
|
| 4 |
+
"pretrain_token_positions": 4000000000,
|
| 5 |
+
"pretrain_steps": 976563,
|
| 6 |
+
"batch_size": 8,
|
| 7 |
+
"seq_len": 512,
|
| 8 |
+
"sft_steps": 50000,
|
| 9 |
+
"pretrain_lr": 0.0008,
|
| 10 |
+
"sft_lr": 0.00005,
|
| 11 |
+
"save_every_steps": 100000,
|
| 12 |
+
"block_residual_rms_cap": "null",
|
| 13 |
+
"notes": [
|
| 14 |
+
"Uses streaming JSONL training through TaoTrain CLI.",
|
| 15 |
+
"Counts token positions as batch_size * seq_len * max_steps.",
|
| 16 |
+
"Keeps SSM branch RMS normalization enabled and block residual RMS normalization disabled.",
|
| 17 |
+
"Runs corrected response-only SFT after pretraining."
|
| 18 |
+
]
|
| 19 |
+
}
|
artifacts/status.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"state":"completed","updated_at":"2026-05-15T19:22:45+08:00","job_dir":"/home/student/YouZheng/jobs/taotern/taotern-200m-branch-only-chat-20260514","exit_code":0,"end_time":"2026-05-15T19:22:45+08:00"}
|
code/TaoTrain/.gitignore
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Installer logs
|
| 35 |
+
pip-log.txt
|
| 36 |
+
pip-delete-this-directory.txt
|
| 37 |
+
|
| 38 |
+
# Unit test / coverage reports
|
| 39 |
+
htmlcov/
|
| 40 |
+
.tox/
|
| 41 |
+
.nox/
|
| 42 |
+
.coverage
|
| 43 |
+
.coverage.*
|
| 44 |
+
.cache
|
| 45 |
+
nosetests.xml
|
| 46 |
+
coverage.xml
|
| 47 |
+
*.cover
|
| 48 |
+
*.py,cover
|
| 49 |
+
.hypothesis/
|
| 50 |
+
.pytest_cache/
|
| 51 |
+
|
| 52 |
+
# Translations
|
| 53 |
+
*.mo
|
| 54 |
+
*.pot
|
| 55 |
+
|
| 56 |
+
# Django stuff:
|
| 57 |
+
*.log
|
| 58 |
+
local_settings.py
|
| 59 |
+
db.sqlite3
|
| 60 |
+
db.sqlite3-journal
|
| 61 |
+
|
| 62 |
+
# Flask stuff:
|
| 63 |
+
instance/
|
| 64 |
+
.webassets-cache
|
| 65 |
+
|
| 66 |
+
# Scrapy stuff:
|
| 67 |
+
.scrapy
|
| 68 |
+
|
| 69 |
+
# Sphinx documentation
|
| 70 |
+
docs/_build/
|
| 71 |
+
|
| 72 |
+
# PyBuilder
|
| 73 |
+
target/
|
| 74 |
+
|
| 75 |
+
# Jupyter Notebook
|
| 76 |
+
.ipynb_checkpoints
|
| 77 |
+
|
| 78 |
+
# IPython
|
| 79 |
+
profile_default/
|
| 80 |
+
ipython_config.py
|
| 81 |
+
|
| 82 |
+
# pyenv
|
| 83 |
+
.python-version
|
| 84 |
+
|
| 85 |
+
# pipenv
|
| 86 |
+
Pipfile.lock
|
| 87 |
+
|
| 88 |
+
# PEP 582
|
| 89 |
+
__pypackages__/
|
| 90 |
+
|
| 91 |
+
# Celery stuff
|
| 92 |
+
celerybeat-schedule
|
| 93 |
+
celerybeat.pid
|
| 94 |
+
|
| 95 |
+
# SageMath parsed files
|
| 96 |
+
*.sage.py
|
| 97 |
+
|
| 98 |
+
# Environments
|
| 99 |
+
.env
|
| 100 |
+
.venv
|
| 101 |
+
env/
|
| 102 |
+
venv/
|
| 103 |
+
ENV/
|
| 104 |
+
env.bak/
|
| 105 |
+
venv.bak/
|
| 106 |
+
|
| 107 |
+
# Spyder project settings
|
| 108 |
+
.spyderproject
|
| 109 |
+
.spyproject
|
| 110 |
+
|
| 111 |
+
# Rope project settings
|
| 112 |
+
.ropeproject
|
| 113 |
+
|
| 114 |
+
# mkdocs documentation
|
| 115 |
+
/site
|
| 116 |
+
|
| 117 |
+
# mypy
|
| 118 |
+
.mypy_cache/
|
| 119 |
+
.dmypy.json
|
| 120 |
+
dmypy.json
|
| 121 |
+
|
| 122 |
+
# Pyre type checker
|
| 123 |
+
.pyre/
|
| 124 |
+
|
| 125 |
+
# IDEs
|
| 126 |
+
.vscode/
|
| 127 |
+
.idea/
|
| 128 |
+
*.swp
|
| 129 |
+
*.swo
|
| 130 |
+
*~
|
| 131 |
+
|
| 132 |
+
# OS
|
| 133 |
+
.DS_Store
|
| 134 |
+
Thumbs.db
|
| 135 |
+
|
| 136 |
+
# Project-specific
|
| 137 |
+
checkpoints/
|
| 138 |
+
runs/
|
| 139 |
+
results/
|
| 140 |
+
*.pt
|
| 141 |
+
*.pth
|
| 142 |
+
*.safetensors
|
| 143 |
+
aim_logs/
|
| 144 |
+
.aim/
|
| 145 |
+
wandb/
|
| 146 |
+
lightning_logs/
|
| 147 |
+
|
| 148 |
+
# Dataset cache
|
| 149 |
+
.cache/
|
| 150 |
+
datasets_cache/
|
code/TaoTrain/README.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TaoTrain: Production-Grade LLM Training Framework
|
| 2 |
+
|
| 3 |
+
**TaoTrain** is a sophisticated PyTorch framework for training large language models at every scale—from experimental pretraining through supervised fine-tuning to reinforcement learning. Unlike fragmented training scripts or heavyweight frameworks, TaoTrain unifies the **entire training pipeline** in a clean, modular codebase that appeals to both ML engineers and software engineers.
|
| 4 |
+
|
| 5 |
+
## Current Taotern Work
|
| 6 |
+
|
| 7 |
+
TaoTrain now includes the Taotern comparison architectures used by the current SSM LLM work:
|
| 8 |
+
|
| 9 |
+
- `taonet`: the attention/MLA baseline.
|
| 10 |
+
- `taonet_ssm`: the TaoNet shell with the attention mixer replaced by the Gamma Space Model DPLR SSM.
|
| 11 |
+
- `taonet_hybrid`: an alternating attention/SSM TaoNet used for the current best 200M-class candidate.
|
| 12 |
+
|
| 13 |
+
The current selected deployment-oriented run is `hybrid_ssm_first_199m`, a `199,480,928` parameter model with 16 layers: SSM layers at `0,2,4,6,8,10,12,14` and attention layers at `1,3,5,7,9,11,13,15`. It uses the DPLR SSM core with split two-lane mixing, channel gates, per-channel local shift, and the faster convolution path for long-sequence training.
|
| 14 |
+
|
| 15 |
+
Remote run `taotern-200m-hybrid-chat-20260512` trains this model on TaoData for a 4B-token base stage and then runs SFT so the final artifact can be loaded as a chat model. The trainable fixes added for this run are:
|
| 16 |
+
|
| 17 |
+
- Async JSONL iteration keeps polling while tokenization workers are alive instead of ending early after a temporary empty queue.
|
| 18 |
+
- Cached JSONL scan metadata is reused safely while recomputing chunk ranges for the active `samples_per_chunk` and `max_samples` settings.
|
| 19 |
+
|
| 20 |
+
## Why TaoTrain?
|
| 21 |
+
|
| 22 |
+
- **Complete Unified Pipeline**: Pretraining → SFT → RL in a single, consistent framework. No context switching between different codebases or architectures.
|
| 23 |
+
- **Production-Grade Engineering**: Type-safe Pydantic configs, comprehensive checkpointing, AimStack integration, and proper gradient handling—not research code, but a framework you can deploy.
|
| 24 |
+
- **Extensibility Without Modification**: Register custom models, optimizers, schedulers, and datasets via decorators. Experiment freely without forking the framework.
|
| 25 |
+
- **Developer Experience First**: Interactive TUI for inference, intuitive YAML configurations, async data loading that eliminates I/O bottlenecks, and clear abstractions that make the codebase a pleasure to work with.
|
| 26 |
+
|
| 27 |
+
## Key Capabilities
|
| 28 |
+
|
| 29 |
+
| Capability | Details |
|
| 30 |
+
|---|---|
|
| 31 |
+
| **Multi-Stage Training** | Unified infrastructure for pretraining, SFT, and RL. Share model checkpoints, logging, and evaluation across stages. |
|
| 32 |
+
| **Advanced Optimization** | Hybrid Muon + AdamW optimizer: efficient 2D weight updates via SVD-based methods + adaptive learning for 1D parameters. |
|
| 33 |
+
| **Modern Architectures** | DeepSeek MLA with grouped query attention (GQA), YaRN context extension, and factorized embeddings—all configurable via YAML. |
|
| 34 |
+
| **Production Features** | BF16 mixed precision training, gradient accumulation, proper gradient clipping, checkpoint resumption, and validation loops. |
|
| 35 |
+
| **Async Data Pipeline** | Background tokenization with multi-threaded workers. Stream billion-token datasets from JSONL without loading into memory. |
|
| 36 |
+
| **Interactive Inference** | TUI chat interface with real-time generation speed metrics and multi-model comparison. |
|
| 37 |
+
| **Logging & Monitoring** | AimStack integration tracks loss, metrics, hyperparameters, and git hashes for reproducibility. Visualize training runs in your browser. |
|
| 38 |
+
|
| 39 |
+
## Getting Started
|
| 40 |
+
|
| 41 |
+
### Installation
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
git clone https://github.com/lobakkang/taoTrain.git
|
| 45 |
+
cd taoTrain
|
| 46 |
+
pip install -e .
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### Training Examples
|
| 50 |
+
|
| 51 |
+
**Pretraining on a custom dataset:**
|
| 52 |
+
```bash
|
| 53 |
+
train pretrain --config configs/pretrain.yaml
|
| 54 |
+
```
|
| 55 |
+
Starts from scratch, learns representations from raw text via next-token prediction.
|
| 56 |
+
|
| 57 |
+
**Supervised Fine-tuning:**
|
| 58 |
+
```bash
|
| 59 |
+
train sft --config configs/sft.yaml
|
| 60 |
+
```
|
| 61 |
+
Fine-tune a pretrained model on instruction-response pairs for improved task performance.
|
| 62 |
+
|
| 63 |
+
**Reinforcement Learning (DPO):**
|
| 64 |
+
```bash
|
| 65 |
+
train rl --config configs/rl_dpo.yaml
|
| 66 |
+
```
|
| 67 |
+
Align models with human preferences using Direct Preference Optimization.
|
| 68 |
+
|
| 69 |
+
**Interactive Chat:**
|
| 70 |
+
```bash
|
| 71 |
+
tui-chat --model checkpoints/model.pt
|
| 72 |
+
```
|
| 73 |
+
Launch an interactive TUI to chat with your model and monitor generation metrics in real-time.
|
| 74 |
+
|
| 75 |
+
### Configuration
|
| 76 |
+
|
| 77 |
+
All training is configured via YAML with Pydantic validation. Configs are type-safe and automatically validated:
|
| 78 |
+
|
| 79 |
+
```yaml
|
| 80 |
+
# configs/sft.yaml
|
| 81 |
+
model:
|
| 82 |
+
architecture_type: "mla" # DeepSeek MLA with GQA
|
| 83 |
+
hidden_dim: 2048
|
| 84 |
+
num_layers: 24
|
| 85 |
+
num_heads: 32
|
| 86 |
+
d_latent_kv: 1536 # KV compression factor
|
| 87 |
+
|
| 88 |
+
training:
|
| 89 |
+
num_epochs: 3
|
| 90 |
+
batch_size: 32
|
| 91 |
+
learning_rate: 1e-4
|
| 92 |
+
warmup_ratio: 0.1
|
| 93 |
+
max_grad_norm: 1.0
|
| 94 |
+
|
| 95 |
+
optimizer:
|
| 96 |
+
optimizer_type: "muon_adamw" # Hybrid Muon + AdamW
|
| 97 |
+
muon_momentum: 0.95
|
| 98 |
+
|
| 99 |
+
data:
|
| 100 |
+
dataset_type: "sft_jsonl" # or "sft_hf" for HuggingFace
|
| 101 |
+
path: "data/sft_training.jsonl"
|
| 102 |
+
|
| 103 |
+
logging:
|
| 104 |
+
log_to_aim: true
|
| 105 |
+
aim_repo: "/tmp/aim_logs"
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
See `configs/` for complete examples.
|
| 109 |
+
|
| 110 |
+
## Project Architecture
|
| 111 |
+
|
| 112 |
+
```
|
| 113 |
+
src/taoTrain/
|
| 114 |
+
├── cli.py # Main CLI entry point
|
| 115 |
+
├── config.py # Pydantic configuration schemas
|
| 116 |
+
│
|
| 117 |
+
├── core/ # Base abstractions
|
| 118 |
+
│ └── base.py # BaseModel, BaseDataset, BaseTrainer
|
| 119 |
+
│
|
| 120 |
+
├── models/ # Pluggable architecture system
|
| 121 |
+
│ ├── registry.py # Architecture factory with @register_architecture
|
| 122 |
+
│ ├── taonet.py # SimpleLLM with DeepSeek MLA
|
| 123 |
+
│ ├── mla_components.py # KV compression, GQA, YaRN
|
| 124 |
+
│ ├── embeddings.py # Factorized embeddings
|
| 125 |
+
│ └── transformer.py # Standard Transformer reference
|
| 126 |
+
│
|
| 127 |
+
├── data/ # Advanced data pipeline
|
| 128 |
+
│ ├── factory.py # Dataset factory (HF + JSONL backends)
|
| 129 |
+
│ ├── async_loader.py # Async batch iteration (no I/O bottleneck)
|
| 130 |
+
│ ├── tokenization_queue.py # Background multi-threaded tokenization
|
| 131 |
+
│ ├── chunk_manager.py # Stream billion-token JSONL files
|
| 132 |
+
│ ├── hf_pretrain.py # HuggingFace pretraining datasets
|
| 133 |
+
│ ├── hf_sft.py # HuggingFace SFT datasets
|
| 134 |
+
│ ├── hf_rl.py # HuggingFace RL datasets
|
| 135 |
+
│ ├── pretrain_jsonl.py # JSONL pretraining
|
| 136 |
+
│ ├── sft_jsonl.py # JSONL SFT with instructions
|
| 137 |
+
│ └── rl_jsonl.py # JSONL RL with preferences
|
| 138 |
+
│
|
| 139 |
+
├── training/ # Unified training infrastructure
|
| 140 |
+
│ └── trainer.py # Trainer + PretrainTrainer, SFTTrainer, RLTrainer
|
| 141 |
+
│
|
| 142 |
+
├── optimizers/ # Pluggable optimizer system
|
| 143 |
+
│ ├── registry.py # Optimizer factory with @register_optimizer
|
| 144 |
+
│ ├── hybrid_muon_adamw.py # Composite: Muon (2D) + AdamW (1D)
|
| 145 |
+
│ ├── adamw.py # AdamW with weight decay
|
| 146 |
+
│ ├── adam.py # Standard Adam
|
| 147 |
+
│ └── sgd.py # SGD variants
|
| 148 |
+
│
|
| 149 |
+
├── schedulers/ # Learning rate schedules
|
| 150 |
+
│ ├── registry.py # LR scheduler factory
|
| 151 |
+
│ ├── cosine_warmup.py # 3-phase: linear warmup → plateau → cosine decay
|
| 152 |
+
│ ├── linear_warmup.py # Linear warmup + constant
|
| 153 |
+
│ └── constant.py # Constant learning rate
|
| 154 |
+
│
|
| 155 |
+
├── inference/ # Inference & interaction
|
| 156 |
+
│ ├── inferencer.py # Load & run inference from checkpoints
|
| 157 |
+
│ └── tui.py # Interactive chat with metrics display
|
| 158 |
+
│
|
| 159 |
+
├── checkpointing/ # State management
|
| 160 |
+
│ └── checkpoint.py # Save/load model + optimizer + config + metrics
|
| 161 |
+
│
|
| 162 |
+
├── logging/ # Experiment tracking
|
| 163 |
+
│ └── aim_logger.py # AimStack integration (loss, metrics, hyperparams)
|
| 164 |
+
│
|
| 165 |
+
├── benchmarks/ # Evaluation tools
|
| 166 |
+
│ └── runner.py # Perplexity, speed, and task-specific benchmarks
|
| 167 |
+
│
|
| 168 |
+
└── utils/
|
| 169 |
+
└── helpers.py # Utility functions
|
| 170 |
+
|
| 171 |
+
configs/ # Example YAML configurations
|
| 172 |
+
├── pretrain.yaml # Pretraining config
|
| 173 |
+
├── sft.yaml # SFT config
|
| 174 |
+
├── rl_dpo.yaml # RL/DPO config
|
| 175 |
+
└── tokenizer.yaml # Tokenizer config
|
| 176 |
+
|
| 177 |
+
tests/ # Unit & integration tests
|
| 178 |
+
└── test_dataset.py
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
## Extensible Architecture: The Registry Pattern
|
| 182 |
+
|
| 183 |
+
TaoTrain's power lies in its **pluggable design**. Add custom models, optimizers, schedulers, and datasets without modifying the framework.
|
| 184 |
+
|
| 185 |
+
### Custom Model Architecture
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
from taoTrain.models import register_architecture, BaseModel
|
| 189 |
+
import torch.nn as nn
|
| 190 |
+
|
| 191 |
+
@register_architecture("custom_moe")
|
| 192 |
+
class MixtureOfExperts(BaseModel):
|
| 193 |
+
"""Your custom MoE architecture"""
|
| 194 |
+
def __init__(self, config):
|
| 195 |
+
super().__init__(config)
|
| 196 |
+
self.experts = nn.ModuleList([
|
| 197 |
+
nn.Linear(config.hidden_dim, config.hidden_dim)
|
| 198 |
+
for _ in range(config.num_experts)
|
| 199 |
+
])
|
| 200 |
+
self.router = nn.Linear(config.hidden_dim, config.num_experts)
|
| 201 |
+
|
| 202 |
+
def forward(self, input_ids, attention_mask=None):
|
| 203 |
+
# Your implementation
|
| 204 |
+
logits = self.compute_logits(input_ids)
|
| 205 |
+
loss = self.compute_loss(logits, labels) if labels is not None else None
|
| 206 |
+
return {"logits": logits, "loss": loss}
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
Then use it in your config:
|
| 210 |
+
|
| 211 |
+
```yaml
|
| 212 |
+
model:
|
| 213 |
+
architecture_type: "custom_moe"
|
| 214 |
+
hidden_dim: 2048
|
| 215 |
+
num_experts: 8
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
### Custom Optimizers & Schedulers
|
| 219 |
+
|
| 220 |
+
The same pattern works for optimizers and learning rate schedules:
|
| 221 |
+
|
| 222 |
+
```python
|
| 223 |
+
from taoTrain.optimizers import register_optimizer
|
| 224 |
+
from torch.optim import Optimizer
|
| 225 |
+
|
| 226 |
+
@register_optimizer("my_adaptive_optimizer")
|
| 227 |
+
class MyAdaptiveOptimizer(Optimizer):
|
| 228 |
+
def step(self, closure=None):
|
| 229 |
+
# Your optimization logic
|
| 230 |
+
pass
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
```python
|
| 234 |
+
from taoTrain.schedulers import register_scheduler
|
| 235 |
+
|
| 236 |
+
@register_scheduler("my_schedule")
|
| 237 |
+
def my_schedule(initial_lr, step, total_steps, **kwargs):
|
| 238 |
+
return initial_lr * (1.0 - step / total_steps) # Linear decay
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
**The key principle**: No framework code needs to change. You register once, it's available everywhere.
|
| 242 |
+
|
| 243 |
+
### Dataset Backend Flexibility
|
| 244 |
+
|
| 245 |
+
Define custom datasets (JSONL, HF, streaming, etc.) and let the factory route to them:
|
| 246 |
+
|
| 247 |
+
```python
|
| 248 |
+
from taoTrain.data import register_dataset
|
| 249 |
+
|
| 250 |
+
@register_dataset("pretrain", "my_backend")
|
| 251 |
+
class MyPretrainDataset(BaseDataset):
|
| 252 |
+
def __init__(self, config):
|
| 253 |
+
# Load from your custom backend
|
| 254 |
+
pass
|
| 255 |
+
|
| 256 |
+
def __getitem__(self, idx):
|
| 257 |
+
return {"input_ids": ..., "attention_mask": ...}
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
Use in config:
|
| 261 |
+
|
| 262 |
+
```yaml
|
| 263 |
+
data:
|
| 264 |
+
dataset_type: "pretrain"
|
| 265 |
+
backend_type: "my_backend" # Routes to MyPretrainDataset
|
| 266 |
+
```
|
| 267 |
+
|
| 268 |
+
## Why TaoTrain Framework?
|
| 269 |
+
|
| 270 |
+
### Async Data Loading: No I/O Bottleneck
|
| 271 |
+
|
| 272 |
+
Most training frameworks load and tokenize data on the main training thread, blocking compute. TaoTrain's **multi-threaded tokenization pipeline**:
|
| 273 |
+
|
| 274 |
+
- Tokenizes data in background workers while your GPU trains
|
| 275 |
+
- Supports streaming billion-token JSONL files without loading into memory
|
| 276 |
+
- Intelligent chunking (by file size or sample count)
|
| 277 |
+
- Metadata caching to avoid rescanning
|
| 278 |
+
|
| 279 |
+
**Result**: 10-100x faster data iteration on large datasets.
|
| 280 |
+
|
| 281 |
+
### Type-Safe Configuration
|
| 282 |
+
|
| 283 |
+
Forget YAML parsing errors or mysterious config bugs. TaoTrain uses **Pydantic dataclasses** for configuration:
|
| 284 |
+
|
| 285 |
+
- Automatic type validation: mistyped `learning_rate: "1e-4"` becomes an error, not silent failure
|
| 286 |
+
- Serialization: configs are part of checkpoints, ensuring reproducibility
|
| 287 |
+
- IDE support: autocomplete and type hints for all config fields
|
| 288 |
+
- Defaults: sensible defaults for all parameters
|
| 289 |
+
|
| 290 |
+
### Benchmarking & Metrics
|
| 291 |
+
|
| 292 |
+
Track what matters:
|
| 293 |
+
|
| 294 |
+
- **Perplexity**: Language modeling quality on held-out data
|
| 295 |
+
- **Generation Speed**: Tokens-per-second (useful for TUI or deployment)
|
| 296 |
+
- **Task-Specific Accuracy**: Evaluate on downstream tasks
|
| 297 |
+
- **Training Metrics**: Loss curves, gradient norms, effective batch size
|
| 298 |
+
|
| 299 |
+
All logged to AimStack with git hashes for reproducibility.
|
| 300 |
+
|
| 301 |
+
## Logging with AimStack
|
| 302 |
+
|
| 303 |
+
Automatically track and visualize experiments:
|
| 304 |
+
|
| 305 |
+
```bash
|
| 306 |
+
aim up --host 0.0.0.0
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
Then open `http://localhost:43800` to see:
|
| 310 |
+
|
| 311 |
+
- **Loss curves** per training step
|
| 312 |
+
- **Hyperparameters** (learning rate, batch size, model architecture)
|
| 313 |
+
- **Git hashes** for reproducibility
|
| 314 |
+
- **Custom metrics** (perplexity, validation accuracy, generation speed)
|
| 315 |
+
- **Compare runs**: Side-by-side experiment comparison
|
| 316 |
+
|
| 317 |
+
## Advanced Features
|
| 318 |
+
|
| 319 |
+
### Checkpointing with Resumption
|
| 320 |
+
|
| 321 |
+
TaoTrain saves complete training state:
|
| 322 |
+
|
| 323 |
+
```python
|
| 324 |
+
checkpoint = {
|
| 325 |
+
"step": 12500,
|
| 326 |
+
"model_state": model.state_dict(),
|
| 327 |
+
"optimizer_state": optimizer.state_dict(),
|
| 328 |
+
"config": config, # Full config as Pydantic object
|
| 329 |
+
"metrics": metrics_tracker.to_dict(),
|
| 330 |
+
}
|
| 331 |
+
```
|
| 332 |
+
|
| 333 |
+
Resume training from any checkpoint without loss of state. Keep last N checkpoints automatically.
|
| 334 |
+
|
| 335 |
+
### Mixed Precision Training (BF16)
|
| 336 |
+
|
| 337 |
+
```yaml
|
| 338 |
+
training:
|
| 339 |
+
use_bfloat16: true
|
| 340 |
+
gradient_accumulation_steps: 4
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
- BF16 via `torch.autocast` for ~2x speedup with minimal accuracy loss
|
| 344 |
+
- Proper gradient scaling and clipping
|
| 345 |
+
- Compatible with all optimizers and architectures
|
| 346 |
+
|
| 347 |
+
### 3-Phase Learning Rate Schedule
|
| 348 |
+
|
| 349 |
+
```yaml
|
| 350 |
+
scheduler:
|
| 351 |
+
scheduler_type: "cosine_warmup"
|
| 352 |
+
warmup_ratio: 0.1 # 10% of training steps
|
| 353 |
+
steady_ratio: 0.5 # 50% at steady rate
|
| 354 |
+
min_lr_ratio: 0.1 # Final LR = 0.1 × initial_lr
|
| 355 |
+
num_cycles: 1
|
| 356 |
+
```
|
| 357 |
+
|
| 358 |
+
This schedule:
|
| 359 |
+
1. **Linear warmup** (0 → 1) over 10% of steps
|
| 360 |
+
2. **Steady plateau** at full LR over 50% of steps
|
| 361 |
+
3. **Cosine decay** (1 → 0.1) over remaining 40% of steps
|
| 362 |
+
|
| 363 |
+
Better convergence than simple cosine or linear decay.
|
| 364 |
+
|
| 365 |
+
### Gradient Accumulation & Clipping
|
| 366 |
+
|
| 367 |
+
Simulate larger batch sizes with gradient accumulation:
|
| 368 |
+
|
| 369 |
+
```yaml
|
| 370 |
+
training:
|
| 371 |
+
batch_size: 32
|
| 372 |
+
gradient_accumulation_steps: 4 # Effective batch = 128
|
| 373 |
+
max_grad_norm: 1.0 # Gradient clipping
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
## Contributing
|
| 377 |
+
|
| 378 |
+
Contributions are welcome! TaoTrain is designed to make contributions easy:
|
| 379 |
+
|
| 380 |
+
1. **Add a model**: Implement `BaseModel` and `@register_architecture("name")`
|
| 381 |
+
2. **Add an optimizer**: Implement `torch.optim.Optimizer` and `@register_optimizer("name")`
|
| 382 |
+
3. **Add a dataset**: Implement `BaseDataset` and `@register_dataset(mode, backend_type)`
|
| 383 |
+
4. **Improve the core**: Submit PRs to `training/`, `data/`, `logging/`, etc.
|
| 384 |
+
|
| 385 |
+
Ensure new code includes:
|
| 386 |
+
- Type hints throughout
|
| 387 |
+
- Pydantic configs for new parameters
|
| 388 |
+
- Unit tests in `tests/`
|
| 389 |
+
- Documentation in docstrings and README
|
| 390 |
+
|
| 391 |
+
## Current Scope & Roadmap
|
| 392 |
+
|
| 393 |
+
### ✅ Currently Supported
|
| 394 |
+
|
| 395 |
+
- **Single GPU / single node** training
|
| 396 |
+
- **Pretraining, SFT, and RL training** stages
|
| 397 |
+
- **HuggingFace and JSONL** data backends
|
| 398 |
+
- **BF16 mixed precision** training
|
| 399 |
+
- **Checkpoint saving/loading** with resumption
|
| 400 |
+
- **Interactive inference** via TUI
|
| 401 |
+
- **Benchmarking** (perplexity, speed)
|
| 402 |
+
- **Pluggable architectures, optimizers, schedulers, datasets**
|
| 403 |
+
|
| 404 |
+
### 🚀 Roadmap (Future)
|
| 405 |
+
|
| 406 |
+
- **Distributed training** (DDP, FSDP) for multi-GPU/multi-node scaling
|
| 407 |
+
- **Quantization** support (INT8, QLoRA)
|
| 408 |
+
- **Advanced evaluation** (BLEU, ROUGE, custom tasks)
|
| 409 |
+
- **Streaming inference** with KV cache
|
| 410 |
+
- **Speculative decoding** for faster generation
|
| 411 |
+
- **Integration with popular model hubs** (Hugging Face Hub upload/download)
|
| 412 |
+
|
| 413 |
+
---
|
| 414 |
+
|
| 415 |
+
## Getting Help
|
| 416 |
+
|
| 417 |
+
- **Questions?** Open an issue on GitHub
|
| 418 |
+
- **Want to contribute?** See `CONTRIBUTING.md` (coming soon)
|
| 419 |
+
- **Found a bug?** Report it with a minimal reproduction script
|
| 420 |
+
|
| 421 |
+
## License
|
| 422 |
+
|
| 423 |
+
MIT
|
code/TaoTrain/configs/pretrain.yaml
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TaoNet T2 Configuration for Pretraining
|
| 2 |
+
# DeepSeek MLA + RoPE with Hybrid Muon+AdamW Optimizer
|
| 3 |
+
# Full BF16 precision (no quantization)
|
| 4 |
+
|
| 5 |
+
# ============================================================================
|
| 6 |
+
# Model Architecture - TaoNet (DeepSeek MLA + RoPE)
|
| 7 |
+
# ============================================================================
|
| 8 |
+
model:
|
| 9 |
+
architecture_type: taonet
|
| 10 |
+
vocab_size: 8192
|
| 11 |
+
hidden_dim: 768
|
| 12 |
+
num_layers: 12
|
| 13 |
+
num_heads: 8
|
| 14 |
+
max_seq_length: 1024
|
| 15 |
+
|
| 16 |
+
# TaoNet-specific: Multi-head Latent Attention (MLA) compression
|
| 17 |
+
d_latent_kv: 512
|
| 18 |
+
|
| 19 |
+
# RoPE (Rotary Position Embedding) dimension per head
|
| 20 |
+
# Default would be 512 / 8 = 64
|
| 21 |
+
d_rope: 64
|
| 22 |
+
|
| 23 |
+
# Feed-forward intermediate dimension
|
| 24 |
+
hidden_dim_ff: 2048
|
| 25 |
+
|
| 26 |
+
# Dropout rate (low for stability with large models)
|
| 27 |
+
dropout: 0.02
|
| 28 |
+
|
| 29 |
+
# Grouped Query Attention (1 = standard MLA, >1 = GQA)
|
| 30 |
+
gqa_groups: 1
|
| 31 |
+
|
| 32 |
+
# Optional: Use factorized embedding for parameter efficiency
|
| 33 |
+
# vocab (8192) → rank (96) → hidden (512)
|
| 34 |
+
use_factorized_embedding: false
|
| 35 |
+
d_embed_rank: 96
|
| 36 |
+
|
| 37 |
+
# Weight initialization standard deviation
|
| 38 |
+
init_std: 0.02
|
| 39 |
+
|
| 40 |
+
# ============================================================================
|
| 41 |
+
# Dataset Configuration - Local JSONL
|
| 42 |
+
# ============================================================================
|
| 43 |
+
dataset:
|
| 44 |
+
local: true
|
| 45 |
+
jsonl_path: /home/student/Data/TaoData/output.jsonl
|
| 46 |
+
text_field: text
|
| 47 |
+
max_samples: 1000000
|
| 48 |
+
samples_per_chunk: 1000
|
| 49 |
+
|
| 50 |
+
# Tokenizer configuration
|
| 51 |
+
tokenizer_type: sentencepiece
|
| 52 |
+
tokenizer_path: tokenizer/tokenizer.model
|
| 53 |
+
tokenizer_threads: 4
|
| 54 |
+
|
| 55 |
+
# ============================================================================
|
| 56 |
+
# Training Hyperparameters
|
| 57 |
+
# ============================================================================
|
| 58 |
+
batch_size: 32
|
| 59 |
+
num_epochs: 2 # Set to 10 for full training
|
| 60 |
+
gradient_accumulation_steps: 8 # Effective batch: 32 × 8 = 256
|
| 61 |
+
|
| 62 |
+
# Maximum gradient norm for clipping (prevents ternary instability)
|
| 63 |
+
max_grad_norm: 1.0
|
| 64 |
+
|
| 65 |
+
# ============================================================================
|
| 66 |
+
# Optimizer - Hybrid Muon + AdamW
|
| 67 |
+
# ============================================================================
|
| 68 |
+
# Strategy:
|
| 69 |
+
# - Muon: For 2D Linear weight matrices (orthogonal/SVD-based optimization)
|
| 70 |
+
# - 2D weights: learning_rate (5e-3)
|
| 71 |
+
# - AdamW: For 1D parameters (biases, norms, embeddings)
|
| 72 |
+
# - 1D params: adamw_lr (5e-4) = 1/10 × learning_rate
|
| 73 |
+
|
| 74 |
+
optimizer:
|
| 75 |
+
optimizer_type: hybrid_muon_adamw
|
| 76 |
+
|
| 77 |
+
# Learning rate for Muon (2D weight matrices)
|
| 78 |
+
learning_rate: 5e-3
|
| 79 |
+
|
| 80 |
+
# Learning rate for AdamW (1D parameters)
|
| 81 |
+
# Typically 1/10 of learning_rate to prevent over-updating 1D params
|
| 82 |
+
adamw_lr: 5e-4
|
| 83 |
+
|
| 84 |
+
# L2 regularization (weight decay)
|
| 85 |
+
weight_decay: 0.01
|
| 86 |
+
|
| 87 |
+
# Adam betas
|
| 88 |
+
betas: [0.9, 0.999]
|
| 89 |
+
|
| 90 |
+
# Epsilon for numerical stability
|
| 91 |
+
eps: 1e-8
|
| 92 |
+
|
| 93 |
+
# ============================================================================
|
| 94 |
+
# Learning Rate Scheduler - 3-Phase Cosine with Warmup
|
| 95 |
+
# ============================================================================
|
| 96 |
+
# Phases:
|
| 97 |
+
# 1. Warmup: 0 → 1.0 (300 steps, ~1.4% of training)
|
| 98 |
+
# 2. Steady: 1.0 (constant for 5% of training)
|
| 99 |
+
# 3. Decay: 1.0 → 0.1 (cosine decay for remaining 95%)
|
| 100 |
+
|
| 101 |
+
scheduler:
|
| 102 |
+
scheduler_type: cosineWarmup
|
| 103 |
+
warmup_steps: 300
|
| 104 |
+
warmup_ratio: 0.0 # Ignored if warmup_steps > 0
|
| 105 |
+
steady_ratio: 0.05 # 5% of total training steps at peak LR
|
| 106 |
+
min_lr_ratio: 0.1 # Decay to 10% of peak LR
|
| 107 |
+
num_cycles: 0.5 # For compatibility (not used in 3-phase schedule)
|
| 108 |
+
|
| 109 |
+
# ============================================================================
|
| 110 |
+
# Data Type and Device
|
| 111 |
+
# ============================================================================
|
| 112 |
+
dtype: bfloat16 # Use BF16 for better convergence with large models
|
| 113 |
+
device: cuda # Use GPU for training
|
| 114 |
+
|
| 115 |
+
# ============================================================================
|
| 116 |
+
# Checkpointing and Validation
|
| 117 |
+
# ============================================================================
|
| 118 |
+
checkpoint_dir: checkpoints/test
|
| 119 |
+
save_every_steps: 81920
|
| 120 |
+
save_best_model: true
|
| 121 |
+
keep_last_n_checkpoints: 3
|
| 122 |
+
|
| 123 |
+
# Validation
|
| 124 |
+
eval_every_steps: 8192
|
| 125 |
+
eval_samples: 8000
|
| 126 |
+
|
| 127 |
+
# ============================================================================
|
| 128 |
+
# Logging
|
| 129 |
+
# ============================================================================
|
| 130 |
+
log_every_steps: 50
|
| 131 |
+
aim_repo: .aim
|
| 132 |
+
|
| 133 |
+
# ============================================================================
|
| 134 |
+
# Miscellaneous
|
| 135 |
+
# ============================================================================
|
| 136 |
+
seed: 42
|
| 137 |
+
num_workers: 0
|
| 138 |
+
pin_memory: true
|
code/TaoTrain/configs/rl_dpo.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Example configuration for RL training (RL stage assumes you have a reward model)
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
architecture_type: transformer
|
| 5 |
+
vocab_size: 50257
|
| 6 |
+
hidden_dim: 256
|
| 7 |
+
num_layers: 8
|
| 8 |
+
num_heads: 8
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
max_seq_length: 512
|
| 11 |
+
init_std: 0.02
|
| 12 |
+
|
| 13 |
+
dataset:
|
| 14 |
+
dataset_name: allenai/real_toxicity_prompts
|
| 15 |
+
split: train
|
| 16 |
+
prompt_column: text
|
| 17 |
+
max_samples: 2000
|
| 18 |
+
cache_dir: .cache/datasets
|
| 19 |
+
tokenizer_threads: 1 # Number of background threads for tokenization (1-32 recommended)
|
| 20 |
+
|
| 21 |
+
batch_size: 4
|
| 22 |
+
num_epochs: 1
|
| 23 |
+
gradient_accumulation_steps: 8
|
| 24 |
+
max_grad_norm: 0.5
|
| 25 |
+
|
| 26 |
+
optimizer:
|
| 27 |
+
optimizer_type: adamw
|
| 28 |
+
learning_rate: 1e-5
|
| 29 |
+
weight_decay: 0.0
|
| 30 |
+
|
| 31 |
+
scheduler:
|
| 32 |
+
scheduler_type: linearWarmup
|
| 33 |
+
warmup_steps: 50
|
| 34 |
+
|
| 35 |
+
dtype: bfloat16
|
| 36 |
+
device: cuda
|
| 37 |
+
|
| 38 |
+
checkpoint_dir: checkpoints/rl
|
| 39 |
+
save_every_steps: 100
|
| 40 |
+
save_best_model: false
|
| 41 |
+
keep_last_n_checkpoints: 2
|
| 42 |
+
|
| 43 |
+
eval_every_steps: 100
|
| 44 |
+
eval_samples: 100
|
| 45 |
+
|
| 46 |
+
log_every_steps: 10
|
| 47 |
+
aim_repo: .aim
|
| 48 |
+
|
| 49 |
+
# RL-specific settings
|
| 50 |
+
rl_method: ppo # or "dpo"
|
| 51 |
+
reward_model_path: checkpoints/reward_model.pt # Path to your reward model
|
| 52 |
+
ppo_epochs: 4
|
| 53 |
+
ppo_clip_ratio: 0.2
|
| 54 |
+
entropy_coeff: 0.01
|
| 55 |
+
value_loss_coeff: 1.0
|
| 56 |
+
generation_max_length: 256
|
| 57 |
+
|
| 58 |
+
seed: 42
|
| 59 |
+
num_workers: 0
|
| 60 |
+
pin_memory: true
|
code/TaoTrain/configs/sft.yaml
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Example configuration for supervised fine-tuning
|
| 2 |
+
# Uses TaoNet (MLA+RoPE) architecture loaded from pretrained checkpoint
|
| 3 |
+
|
| 4 |
+
# ============================================================================
|
| 5 |
+
# Model Architecture - TaoNet (DeepSeek MLA + RoPE)
|
| 6 |
+
# ============================================================================
|
| 7 |
+
model:
|
| 8 |
+
architecture_type: taonet
|
| 9 |
+
vocab_size: 8192
|
| 10 |
+
hidden_dim: 768
|
| 11 |
+
num_layers: 12
|
| 12 |
+
num_heads: 8
|
| 13 |
+
max_seq_length: 1024
|
| 14 |
+
|
| 15 |
+
# TaoNet-specific: Multi-head Latent Attention (MLA) compression
|
| 16 |
+
d_latent_kv: 512
|
| 17 |
+
|
| 18 |
+
# RoPE (Rotary Position Embedding) dimension per head
|
| 19 |
+
# Default would be 512 / 8 = 64
|
| 20 |
+
d_rope: 64
|
| 21 |
+
|
| 22 |
+
# Feed-forward intermediate dimension
|
| 23 |
+
hidden_dim_ff: 2048
|
| 24 |
+
|
| 25 |
+
# Dropout rate (low for stability with large models)
|
| 26 |
+
dropout: 0.02
|
| 27 |
+
|
| 28 |
+
# Grouped Query Attention (1 = standard MLA, >1 = GQA)
|
| 29 |
+
gqa_groups: 1
|
| 30 |
+
|
| 31 |
+
# Optional: Use factorized embedding for parameter efficiency
|
| 32 |
+
# vocab (8192) → rank (96) → hidden (512)
|
| 33 |
+
use_factorized_embedding: false
|
| 34 |
+
d_embed_rank: 96
|
| 35 |
+
|
| 36 |
+
# Weight initialization standard deviation
|
| 37 |
+
init_std: 0.02
|
| 38 |
+
|
| 39 |
+
dataset:
|
| 40 |
+
split: train
|
| 41 |
+
instruction_column: input
|
| 42 |
+
response_column: output
|
| 43 |
+
|
| 44 |
+
local: true
|
| 45 |
+
jsonl_path: /home/student/Data/TaoData/sft.jsonl
|
| 46 |
+
samples_per_chunk: 1000
|
| 47 |
+
#max_samples: 2000000
|
| 48 |
+
max_samples: 1000000
|
| 49 |
+
cache_dir: .cache/datasets
|
| 50 |
+
instruction_template: "{instruction}\n{response}"
|
| 51 |
+
|
| 52 |
+
# Tokenizer configuration
|
| 53 |
+
tokenizer_type: sentencepiece
|
| 54 |
+
tokenizer_path: tokenizer/tokenizer.model
|
| 55 |
+
tokenizer_threads: 4
|
| 56 |
+
|
| 57 |
+
# SFT-specific configuration (these fields are in SFTConfig)
|
| 58 |
+
checkpoint_path: "checkpoints/pretrain/final_model.pt"
|
| 59 |
+
user_token: "<user>"
|
| 60 |
+
assistant_token: "<assistant>"
|
| 61 |
+
response_loss_only: true
|
| 62 |
+
|
| 63 |
+
batch_size: 8
|
| 64 |
+
num_epochs: 1
|
| 65 |
+
gradient_accumulation_steps: 4
|
| 66 |
+
max_grad_norm: 1.0
|
| 67 |
+
|
| 68 |
+
optimizer:
|
| 69 |
+
optimizer_type: adamw
|
| 70 |
+
learning_rate: 5e-5 # Lower LR for fine-tuning (vs 5e-4 pretrain base, 5e-3 Muon)
|
| 71 |
+
weight_decay: 0.01
|
| 72 |
+
|
| 73 |
+
scheduler:
|
| 74 |
+
scheduler_type: linearWarmup
|
| 75 |
+
warmup_steps: 500 # Less aggressive warmup for fine-tuning
|
| 76 |
+
|
| 77 |
+
dtype: bfloat16
|
| 78 |
+
device: cuda
|
| 79 |
+
|
| 80 |
+
checkpoint_dir: checkpoints/sft
|
| 81 |
+
save_every_steps: 81920
|
| 82 |
+
save_best_model: true
|
| 83 |
+
keep_last_n_checkpoints: 2
|
| 84 |
+
|
| 85 |
+
eval_every_steps: 8192
|
| 86 |
+
eval_samples: 200
|
| 87 |
+
|
| 88 |
+
log_every_steps: 10
|
| 89 |
+
aim_repo: .aim
|
| 90 |
+
|
| 91 |
+
seed: 42
|
| 92 |
+
num_workers: 0
|
| 93 |
+
pin_memory: true
|
code/TaoTrain/configs/ssm_pretrain.yaml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TaoNet-SSM T2 Configuration for Pretraining
|
| 2 |
+
# TaoNet shell with Gamma SSM replacing the MLA attention mixer.
|
| 3 |
+
|
| 4 |
+
model:
|
| 5 |
+
architecture_type: taonet_ssm
|
| 6 |
+
vocab_size: 8192
|
| 7 |
+
hidden_dim: 768
|
| 8 |
+
num_layers: 12
|
| 9 |
+
num_heads: 8
|
| 10 |
+
max_seq_length: 1024
|
| 11 |
+
|
| 12 |
+
d_latent_kv: 512
|
| 13 |
+
d_rope: 64
|
| 14 |
+
hidden_dim_ff: 2048
|
| 15 |
+
dropout: 0.02
|
| 16 |
+
gqa_groups: 1
|
| 17 |
+
use_factorized_embedding: false
|
| 18 |
+
d_embed_rank: 96
|
| 19 |
+
init_std: 0.02
|
| 20 |
+
|
| 21 |
+
ssm_core: dplr
|
| 22 |
+
ssm_hidden_dim: 512
|
| 23 |
+
ssm_mixer_dim: 256
|
| 24 |
+
ssm_rank: 1
|
| 25 |
+
ssm_max_low_rank_scale: 0.1
|
| 26 |
+
ssm_discretization: bilinear
|
| 27 |
+
ssm_kernel_mode: auto
|
| 28 |
+
ssm_kernel_threshold: 64
|
| 29 |
+
ssm_dt_min: 1e-3
|
| 30 |
+
ssm_dt_max: 1e-1
|
| 31 |
+
ssm_dt_init: 1e-2
|
| 32 |
+
ssm_use_d: true
|
| 33 |
+
ssm_activation: gelu
|
| 34 |
+
ssm_gate: true
|
| 35 |
+
ssm_input_gate: true
|
| 36 |
+
ssm_use_padding_mask: false
|
| 37 |
+
ssm_layer_scale_init: 0.1
|
| 38 |
+
|
| 39 |
+
dataset:
|
| 40 |
+
local: true
|
| 41 |
+
jsonl_path: /home/student/Data/TaoData/output.jsonl
|
| 42 |
+
text_field: text
|
| 43 |
+
max_samples: 1000000
|
| 44 |
+
samples_per_chunk: 1000
|
| 45 |
+
|
| 46 |
+
tokenizer_type: sentencepiece
|
| 47 |
+
tokenizer_path: tokenizer/tokenizer.model
|
| 48 |
+
tokenizer_threads: 4
|
| 49 |
+
|
| 50 |
+
batch_size: 32
|
| 51 |
+
num_epochs: 2
|
| 52 |
+
gradient_accumulation_steps: 8
|
| 53 |
+
max_grad_norm: 1.0
|
| 54 |
+
|
| 55 |
+
optimizer:
|
| 56 |
+
optimizer_type: hybrid_muon_adamw
|
| 57 |
+
learning_rate: 5e-3
|
| 58 |
+
adamw_lr: 5e-4
|
| 59 |
+
weight_decay: 0.01
|
| 60 |
+
betas: [0.9, 0.999]
|
| 61 |
+
eps: 1e-8
|
| 62 |
+
|
| 63 |
+
scheduler:
|
| 64 |
+
scheduler_type: cosineWarmup
|
| 65 |
+
warmup_steps: 300
|
| 66 |
+
warmup_ratio: 0.0
|
| 67 |
+
steady_ratio: 0.05
|
| 68 |
+
min_lr_ratio: 0.1
|
| 69 |
+
num_cycles: 0.5
|
| 70 |
+
|
| 71 |
+
dtype: bfloat16
|
| 72 |
+
device: cuda
|
| 73 |
+
|
| 74 |
+
checkpoint_dir: checkpoints/ssm_test
|
| 75 |
+
save_every_steps: 81920
|
| 76 |
+
save_best_model: true
|
| 77 |
+
keep_last_n_checkpoints: 3
|
| 78 |
+
|
| 79 |
+
eval_every_steps: 8192
|
| 80 |
+
eval_samples: 8000
|
| 81 |
+
|
| 82 |
+
log_every_steps: 50
|
| 83 |
+
aim_repo: .aim
|
| 84 |
+
|
| 85 |
+
seed: 42
|
| 86 |
+
num_workers: 0
|
| 87 |
+
pin_memory: true
|
code/TaoTrain/configs/tokenizer.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Example configuration for training a SentencePiece tokenizer from JSONL data
|
| 2 |
+
|
| 3 |
+
# Dataset source - JSONL file
|
| 4 |
+
jsonl_path: /home/student/Data/TaoData/output.jsonl
|
| 5 |
+
text_field: text # Field name in JSON for text data
|
| 6 |
+
|
| 7 |
+
# Tokenizer training parameters
|
| 8 |
+
vocab_size: 8192
|
| 9 |
+
model_type: unigram # SentencePiece model type: unigram, bpe, char, word
|
| 10 |
+
character_coverage: 0.9995
|
| 11 |
+
|
| 12 |
+
# Output configuration
|
| 13 |
+
output_dir: tokenizer
|
| 14 |
+
tokenizer_prefix: tokenizer
|
| 15 |
+
|
| 16 |
+
# Token ID configuration
|
| 17 |
+
unk_id: 0 # Unknown token ID
|
| 18 |
+
bos_id: 1 # Beginning of sentence token ID
|
| 19 |
+
eos_id: 2 # End of sentence token ID
|
| 20 |
+
pad_id: 3 # Padding token ID
|
| 21 |
+
|
| 22 |
+
# Custom special tokens
|
| 23 |
+
# These will be added to the vocabulary with explicit IDs
|
| 24 |
+
# Useful for control tokens like <think>, <user>, <assistant>, etc.
|
| 25 |
+
# Note: Use \n for newline token, \t for tab, etc.
|
| 26 |
+
special_tokens:
|
| 27 |
+
<PAD>: 3 # Padding (typically same as pad_id above)
|
| 28 |
+
<EOS>: 2 # End of sentence (typically same as eos_id above)
|
| 29 |
+
<BOS>: 1 # Beginning of sentence (typically same as bos_id above)
|
| 30 |
+
<UNK>: 0 # Unknown (typically same as unk_id above)
|
| 31 |
+
"\n": 4 # Newline token - quoted to preserve literal \n in YAML
|
| 32 |
+
<think>: 8 # Special token for chain-of-thought reasoning
|
| 33 |
+
<user>: 9 # User message token
|
| 34 |
+
<assistant>: 10 # Assistant message token
|
| 35 |
+
<image>: 11 # Image token for multimodal models
|
| 36 |
+
|
| 37 |
+
# Data sampling (optional)
|
| 38 |
+
# Set to a number to train on only the first N samples from the JSONL file
|
| 39 |
+
# Useful for quick testing or sub-sampling large datasets
|
| 40 |
+
# Omit or set to null to use entire file
|
| 41 |
+
max_samples: 1000000
|
| 42 |
+
|
| 43 |
+
# Optional metadata
|
| 44 |
+
tokenizer_name: tokenizer
|
code/TaoTrain/configs/tokenizer_taodata_pilot.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pilot SentencePiece tokenizer for remote TaoData benchmarks.
|
| 2 |
+
#
|
| 3 |
+
# This is intentionally smaller than full production tokenizer training so the
|
| 4 |
+
# model-comparison loop can validate attention vs SSM on realistic tokenization.
|
| 5 |
+
|
| 6 |
+
jsonl_path: /home/student/Data/TaoData/pretrain.jsonl.fineweb.jsonl
|
| 7 |
+
text_field: text
|
| 8 |
+
|
| 9 |
+
vocab_size: 8192
|
| 10 |
+
model_type: unigram
|
| 11 |
+
character_coverage: 0.9995
|
| 12 |
+
|
| 13 |
+
output_dir: /home/student/YouZheng/tokenizers/taodata_pilot_8k
|
| 14 |
+
tokenizer_prefix: tokenizer
|
| 15 |
+
|
| 16 |
+
unk_id: 0
|
| 17 |
+
bos_id: 1
|
| 18 |
+
eos_id: 2
|
| 19 |
+
pad_id: 3
|
| 20 |
+
|
| 21 |
+
max_samples: 20000
|
| 22 |
+
tokenizer_name: taodata_pilot_8k
|
code/TaoTrain/configs/yarn_pretrain.yaml
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TaoNet T2 Configuration for YaRN Continued Pretraining
|
| 2 |
+
# Extended Context: 1024 → 8192 tokens with frequency interpolation
|
| 3 |
+
# Built on DeepSeek MLA + RoPE with Hybrid Muon+AdamW Optimizer
|
| 4 |
+
# Full BF16 precision (no quantization)
|
| 5 |
+
|
| 6 |
+
# ============================================================================
|
| 7 |
+
# Model Architecture - TaoNet (DeepSeek MLA + RoPE) with YaRN Extension
|
| 8 |
+
# ============================================================================
|
| 9 |
+
model:
|
| 10 |
+
architecture_type: taonet
|
| 11 |
+
vocab_size: 8192
|
| 12 |
+
hidden_dim: 512
|
| 13 |
+
num_layers: 12
|
| 14 |
+
num_heads: 8
|
| 15 |
+
max_seq_length: 8192 # Extended from 1024 → 8192 (8x longer context)
|
| 16 |
+
|
| 17 |
+
# TaoNet-specific: Multi-head Latent Attention (MLA) compression
|
| 18 |
+
# KV dimension reduced from 512 to 384 (25% memory savings)
|
| 19 |
+
d_latent_kv: 384
|
| 20 |
+
|
| 21 |
+
# RoPE (Rotary Position Embedding) dimension per head
|
| 22 |
+
# Default would be 512 / 8 = 64
|
| 23 |
+
d_rope: 64
|
| 24 |
+
|
| 25 |
+
# Feed-forward intermediate dimension
|
| 26 |
+
# Default would be 4 * 512 = 2048
|
| 27 |
+
hidden_dim_ff: 1024
|
| 28 |
+
|
| 29 |
+
# Dropout rate (low for stability)
|
| 30 |
+
dropout: 0.02
|
| 31 |
+
|
| 32 |
+
# Grouped Query Attention (1 = standard MLA, >1 = GQA)
|
| 33 |
+
gqa_groups: 1
|
| 34 |
+
|
| 35 |
+
# Optional: Use factorized embedding for parameter efficiency
|
| 36 |
+
use_factorized_embedding: false
|
| 37 |
+
d_embed_rank: 96
|
| 38 |
+
|
| 39 |
+
# Weight initialization standard deviation
|
| 40 |
+
init_std: 0.02
|
| 41 |
+
|
| 42 |
+
# ========================================================================
|
| 43 |
+
# YaRN (Yet another RoPE eXtension) Configuration
|
| 44 |
+
# ========================================================================
|
| 45 |
+
# Enables frequency interpolation to extend context length from 1024 → 8192
|
| 46 |
+
# The model learns to "pack" RoPE frequencies into the new longer context during training.
|
| 47 |
+
|
| 48 |
+
# RoPE base scale factor (explicit, previously hardcoded to 40)
|
| 49 |
+
rope_scale: 40.0
|
| 50 |
+
|
| 51 |
+
# Enable YaRN frequency interpolation
|
| 52 |
+
yarn_enabled: true
|
| 53 |
+
|
| 54 |
+
# Interpolation smoothness parameter
|
| 55 |
+
# - 1.0 (default): Smooth, gradual interpolation—safer for learning extended context
|
| 56 |
+
# - 0.5: Aggressive interpolation—faster context expansion, higher risk
|
| 57 |
+
# - 2.0: Conservative interpolation—safer but slower adaptation
|
| 58 |
+
# Recommendation: Start with 1.0; tune in follow-up runs if convergence issues
|
| 59 |
+
yarn_alpha: 1.0
|
| 60 |
+
|
| 61 |
+
# ============================================================================
|
| 62 |
+
# Dataset Configuration - Local JSONL (Same as Pretrain)
|
| 63 |
+
# ============================================================================
|
| 64 |
+
dataset:
|
| 65 |
+
local: true
|
| 66 |
+
jsonl_path: /home/student/Data/TaoData/output.jsonl
|
| 67 |
+
text_field: text
|
| 68 |
+
max_samples: 50000 # Reduced from 1M → 50k for quick YaRN adaptation
|
| 69 |
+
samples_per_chunk: 1000
|
| 70 |
+
|
| 71 |
+
# Tokenizer configuration (unchanged)
|
| 72 |
+
tokenizer_type: sentencepiece
|
| 73 |
+
tokenizer_path: tokenizer/tokenizer.model
|
| 74 |
+
tokenizer_threads: 4
|
| 75 |
+
|
| 76 |
+
# ============================================================================
|
| 77 |
+
# Training Hyperparameters - Conservative for Context Extension
|
| 78 |
+
# ============================================================================
|
| 79 |
+
# Strategy: Lower learning rates + smaller batch to prevent catastrophic forgetting
|
| 80 |
+
# while the model learns to use 8x longer context.
|
| 81 |
+
|
| 82 |
+
batch_size: 16 # Reduced from 32 (8192 tokens/seq is memory-intensive)
|
| 83 |
+
num_epochs: 1 # 50k samples / effective_batch=256 ≈ 200 updates (1 epoch sufficient for warm-start)
|
| 84 |
+
|
| 85 |
+
# Gradient accumulation to maintain effective batch size of ~256
|
| 86 |
+
# Effective batch = batch_size × gradient_accumulation_steps = 16 × 16 = 256
|
| 87 |
+
gradient_accumulation_steps: 16
|
| 88 |
+
|
| 89 |
+
# Maximum gradient norm for clipping
|
| 90 |
+
max_grad_norm: 1.0
|
| 91 |
+
|
| 92 |
+
# ============================================================================
|
| 93 |
+
# Optimizer - Hybrid Muon + AdamW (Conservative LR for Stability)
|
| 94 |
+
# ============================================================================
|
| 95 |
+
# Strategy: Use 1/2 of pretrain learning rates to:
|
| 96 |
+
# 1. Avoid catastrophic forgetting of learned features
|
| 97 |
+
# 2. Allow smooth adaptation to YaRN-scaled RoPE frequencies
|
| 98 |
+
# 3. Give the model time to learn how to use extended context
|
| 99 |
+
|
| 100 |
+
optimizer:
|
| 101 |
+
optimizer_type: hybrid_muon_adamw
|
| 102 |
+
|
| 103 |
+
# Learning rate for Muon (2D weight matrices)
|
| 104 |
+
# Reduced: 5e-3 → 2.5e-3 (50% of pretrain)
|
| 105 |
+
learning_rate: 2.5e-3
|
| 106 |
+
|
| 107 |
+
# Learning rate for AdamW (1D parameters)
|
| 108 |
+
# Reduced: 5e-4 → 1.25e-4 (25% of pretrain, maintains 1/10 ratio)
|
| 109 |
+
adamw_lr: 1.25e-4
|
| 110 |
+
|
| 111 |
+
# L2 regularization (weight decay)
|
| 112 |
+
weight_decay: 0.01
|
| 113 |
+
|
| 114 |
+
# Adam betas (unchanged)
|
| 115 |
+
betas: [0.9, 0.999]
|
| 116 |
+
|
| 117 |
+
# Epsilon for numerical stability
|
| 118 |
+
eps: 1e-8
|
| 119 |
+
|
| 120 |
+
# ============================================================================
|
| 121 |
+
# Learning Rate Scheduler - 3-Phase Cosine with Warmup (Same as Pretrain)
|
| 122 |
+
# ============================================================================
|
| 123 |
+
# Phases:
|
| 124 |
+
# 1. Warmup: 0 → 1.0 (300 steps, ~1.4% of training)
|
| 125 |
+
# 2. Steady: 1.0 (constant for 5% of training steps at peak LR)
|
| 126 |
+
# 3. Decay: 1.0 → 0.1 (cosine decay for remaining ~95%)
|
| 127 |
+
|
| 128 |
+
scheduler:
|
| 129 |
+
scheduler_type: cosineWarmup
|
| 130 |
+
warmup_steps: 300
|
| 131 |
+
warmup_ratio: 0.0 # Ignored if warmup_steps > 0
|
| 132 |
+
steady_ratio: 0.05 # 5% of total training steps at peak LR
|
| 133 |
+
min_lr_ratio: 0.1 # Decay to 10% of peak LR
|
| 134 |
+
num_cycles: 0.5 # For compatibility (not used in 3-phase schedule)
|
| 135 |
+
|
| 136 |
+
# ============================================================================
|
| 137 |
+
# Data Type and Device
|
| 138 |
+
# ============================================================================
|
| 139 |
+
dtype: bfloat16 # Use BF16 for better convergence with extended context
|
| 140 |
+
device: cuda # Use GPU for training
|
| 141 |
+
|
| 142 |
+
# ============================================================================
|
| 143 |
+
# Checkpointing and Validation
|
| 144 |
+
# ============================================================================
|
| 145 |
+
# Load pretrained checkpoint and continue training
|
| 146 |
+
checkpoint_path: checkpoints/pretrain_taonet/best_model.pt
|
| 147 |
+
checkpoint_dir: checkpoints/yarn_taonet
|
| 148 |
+
save_every_steps: 512 # More frequent saves for 50k samples (200 updates total)
|
| 149 |
+
save_best_model: true
|
| 150 |
+
keep_last_n_checkpoints: 3
|
| 151 |
+
|
| 152 |
+
# Validation every 512 steps (10% of 50k samples)
|
| 153 |
+
eval_every_steps: 512
|
| 154 |
+
eval_samples: 2500 # Reduced from 8000
|
| 155 |
+
|
| 156 |
+
# ============================================================================
|
| 157 |
+
# Logging
|
| 158 |
+
# ============================================================================
|
| 159 |
+
log_every_steps: 50 # Log every 50 updates
|
| 160 |
+
aim_repo: .aim
|
| 161 |
+
|
| 162 |
+
# ============================================================================
|
| 163 |
+
# Miscellaneous
|
| 164 |
+
# ============================================================================
|
| 165 |
+
seed: 42
|
| 166 |
+
num_workers: 0
|
| 167 |
+
pin_memory: true
|
| 168 |
+
|
| 169 |
+
# ============================================================================
|
| 170 |
+
# YaRN Performance Notes
|
| 171 |
+
# ============================================================================
|
| 172 |
+
# Expected memory usage: ~1.5x of pretrain (8x longer seq, half batch)
|
| 173 |
+
# Expected training time: ~50-100 steps/min on H100 (depends on setup)
|
| 174 |
+
# Expected convergence: Loss should decrease over 50k samples; monitor perplexity on 8192-length sequences
|
| 175 |
+
#
|
| 176 |
+
# Tuning recommendations for iterative improvements:
|
| 177 |
+
# 1. If loss is unstable: Reduce learning_rate further (1.25e-3)
|
| 178 |
+
# 2. If loss plateaus quickly: Increase max_samples (100k-200k)
|
| 179 |
+
# 3. If memory OOM: Reduce batch_size to 8 (maintain grad_accum at 16)
|
| 180 |
+
# 4. To speed context expansion: Reduce yarn_alpha to 0.5 (more aggressive)
|
| 181 |
+
# 5. For safer training: Increase yarn_alpha to 2.0 (more conservative)
|
code/TaoTrain/data/sample_data.jsonl
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"text": "The quick brown fox jumps over the lazy dog."}
|
| 2 |
+
{"text": "Python is a powerful programming language used for data science, machine learning, and web development."}
|
| 3 |
+
{"text": "Artificial intelligence and machine learning are transforming industries and creating new opportunities."}
|
| 4 |
+
{"text": "Natural language processing enables computers to understand and generate human language."}
|
| 5 |
+
{"text": "Deep learning models like transformers have revolutionized the field of artificial intelligence."}
|
| 6 |
+
{"text": "Transfer learning allows us to leverage pre-trained models to solve new tasks more efficiently."}
|
| 7 |
+
{"text": "The transformer architecture introduced attention mechanisms that became fundamental to modern NLP."}
|
| 8 |
+
{"text": "Language models trained on large corpora can perform impressive few-shot learning tasks."}
|
| 9 |
+
{"text": "Tokenization is a crucial preprocessing step in natural language processing pipelines."}
|
| 10 |
+
{"text": "SentencePiece is a language-independent tokenization algorithm that handles subword segmentation."}
|
code/TaoTrain/data/sample_sft_multi.jsonl
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"turns": [{"user": "What is machine learning?", "assistant": "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It uses algorithms to identify patterns in data."}, {"user": "Can you give me some examples?", "assistant": "Sure! Examples include: 1) Email spam detection - learns to identify spam patterns, 2) Image recognition - learns to identify objects in images, 3) Recommendation systems - learns user preferences to suggest content, 4) Natural language processing - learns language patterns for tasks like translation."}]}
|
| 2 |
+
{"turns": [{"user": "Explain neural networks", "assistant": "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes (neurons) organized in layers. Each connection has a weight that gets adjusted during training to learn patterns from data."}, {"user": "What layers do they have?", "assistant": "Neural networks typically have three types of layers: 1) Input layer - receives raw data, 2) Hidden layers - process data through weighted connections and activation functions, 3) Output layer - produces the final prediction or classification. Deep networks have multiple hidden layers."}]}
|
| 3 |
+
{"turns": [{"user": "What is a transformer model?", "assistant": "A transformer is a neural network architecture that uses self-attention mechanisms to weigh the importance of different input elements. It processes all input tokens in parallel (unlike RNNs) and captures long-range dependencies effectively. Transformers form the basis of modern large language models like GPT and BERT."}, {"user": "How does attention work?", "assistant": "Attention computes relevance scores between query tokens and all other tokens (keys and values). It learns which tokens to focus on when processing each position. Multi-head attention applies this in parallel with multiple learned representations. This allows the model to attend to different types of relationships simultaneously."}, {"user": "Why is it better than RNNs?", "assistant": "Transformers are better than RNNs because: 1) Parallel processing - all tokens processed simultaneously instead of sequentially, 2) Better long-range dependencies - attention can directly connect distant tokens, 3) Easier to train - no vanishing gradient problems, 4) Scalability - transformers scale better to large datasets and models."}]}
|
| 4 |
+
{"turns": [{"user": "What is the difference between supervised and unsupervised learning?", "assistant": "Supervised learning uses labeled data (input-output pairs) to train models, like predicting house prices from features. Unsupervised learning finds patterns in unlabeled data, like clustering similar customers. Supervised learning is more accurate but requires labeled data; unsupervised is useful for exploration."}, {"user": "What about semi-supervised learning?", "assistant": "Semi-supervised learning uses both labeled and unlabeled data. It's useful when labeling is expensive. The unlabeled data helps improve performance by providing additional patterns. Common techniques include pseudo-labeling (labeling unlabeled data with the current model) and consistency regularization."}]}
|
code/TaoTrain/data/sample_sft_single.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"input": "What is Python?", "output": "Python is a high-level, interpreted programming language known for its simplicity and readability. It supports multiple programming paradigms including procedural, object-oriented, and functional programming."}
|
| 2 |
+
{"input": "Explain recursion", "output": "Recursion is a programming technique where a function calls itself to solve a problem by breaking it down into smaller subproblems. Base cases prevent infinite recursion. Example: factorial(n) = n * factorial(n-1) for n > 1, else 1."}
|
| 3 |
+
{"input": "What are decorators in Python?", "output": "Decorators are functions that modify or wrap other functions or classes without permanently changing their source code. They use the @ symbol and are executed when the function is defined. Example: @staticmethod, @property, or custom decorators."}
|
| 4 |
+
{"input": "How do you handle exceptions in Python?", "output": "Exceptions are handled using try-except-else-finally blocks. The try block contains code that might raise an exception, except catches specific exceptions, else runs if no exception occurred, and finally runs regardless. Custom exceptions can be created by inheriting from Exception."}
|
| 5 |
+
{"input": "What is the difference between lists and tuples?", "output": "Lists are mutable (can be changed after creation) while tuples are immutable (cannot be changed). Lists use square brackets [] and tuples use parentheses (). Both can contain mixed data types. Tuples are faster and can be used as dictionary keys."}
|
code/TaoTrain/pyproject.toml
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "taoTrain"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Clean, modular PyTorch LLM training framework with pluggable architectures, AimStack logging, and TUI inference"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
authors = [
|
| 13 |
+
{ name = "Felix", email = "felix@example.com" }
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
dependencies = [
|
| 17 |
+
"torch>=2.0.0",
|
| 18 |
+
"transformers>=4.30.0",
|
| 19 |
+
"datasets>=2.10.0",
|
| 20 |
+
"pydantic>=2.0.0",
|
| 21 |
+
"pydantic-settings>=2.0.0",
|
| 22 |
+
"aim>=3.15.0",
|
| 23 |
+
"click>=8.1.0",
|
| 24 |
+
"rich>=13.0.0",
|
| 25 |
+
"textual>=0.30.0",
|
| 26 |
+
"numpy>=1.24.0",
|
| 27 |
+
"tqdm>=4.65.0",
|
| 28 |
+
"sentencepiece>=0.1.99",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=7.4.0",
|
| 34 |
+
"pytest-cov>=4.1.0",
|
| 35 |
+
"pytest-xdist>=3.3.0",
|
| 36 |
+
"black>=23.7.0",
|
| 37 |
+
"ruff>=0.0.280",
|
| 38 |
+
"typing-extensions>=4.7.0",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
[project.scripts]
|
| 42 |
+
train = "taoTrain.cli:main"
|
| 43 |
+
train-tokenizer = "taoTrain.cli:train_tokenizer_command"
|
| 44 |
+
tui-chat = "taoTrain.inference.tui:main"
|
| 45 |
+
|
| 46 |
+
[tool.setuptools.packages.find]
|
| 47 |
+
where = ["src"]
|
| 48 |
+
|
| 49 |
+
[tool.setuptools.package-data]
|
| 50 |
+
taoTrain = ["configs/**/*.yaml"]
|
| 51 |
+
|
| 52 |
+
[tool.black]
|
| 53 |
+
line-length = 100
|
| 54 |
+
target-version = ["py310"]
|
| 55 |
+
|
| 56 |
+
[tool.ruff]
|
| 57 |
+
line-length = 100
|
| 58 |
+
target-version = "py310"
|
| 59 |
+
select = ["E", "F", "W", "I", "N", "UP", "RUF"]
|
| 60 |
+
ignore = ["E501"]
|
| 61 |
+
|
| 62 |
+
[tool.pytest.ini_options]
|
| 63 |
+
testpaths = ["tests"]
|
| 64 |
+
python_files = "test_*.py"
|
| 65 |
+
addopts = "--verbose"
|
code/TaoTrain/scripts/benchmark_taonet_real_tokens.py
ADDED
|
@@ -0,0 +1,1034 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Real-text token benchmark for TaoNet attention vs TaoNet-SSM.
|
| 2 |
+
|
| 3 |
+
This script consumes a JSONL or plain-text corpus, tokenizes it into one long
|
| 4 |
+
stream, samples contiguous next-token batches, and compares TaoNet variants
|
| 5 |
+
with identical outer dimensions. It is intentionally lighter than the full
|
| 6 |
+
trainer so it can be used in the RepoBridge iteration loop.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import hashlib
|
| 13 |
+
from contextlib import nullcontext
|
| 14 |
+
from contextlib import redirect_stdout
|
| 15 |
+
import csv
|
| 16 |
+
import io
|
| 17 |
+
import json
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import platform
|
| 22 |
+
import random
|
| 23 |
+
import sys
|
| 24 |
+
import time
|
| 25 |
+
from typing import Any, Iterable
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 30 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 31 |
+
if str(SRC_ROOT) not in sys.path:
|
| 32 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 33 |
+
|
| 34 |
+
from taoTrain.config import ModelConfig
|
| 35 |
+
from taoTrain.models import get_model
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
DTYPES = {
|
| 39 |
+
"float32": torch.float32,
|
| 40 |
+
"fp32": torch.float32,
|
| 41 |
+
"float16": torch.float16,
|
| 42 |
+
"fp16": torch.float16,
|
| 43 |
+
"bfloat16": torch.bfloat16,
|
| 44 |
+
"bf16": torch.bfloat16,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
BYTE_PAD_ID = 0
|
| 48 |
+
BYTE_EOS_ID = 1
|
| 49 |
+
BYTE_UNK_ID = 2
|
| 50 |
+
BYTE_OFFSET = 3
|
| 51 |
+
BYTE_VOCAB_SIZE = 259
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_int_list(value: str) -> list[int]:
|
| 55 |
+
return [int(item.strip()) for item in value.split(",") if item.strip()]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse_float_list(value: str) -> list[float]:
|
| 59 |
+
return [float(item.strip()) for item in value.split(",") if item.strip()]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def parse_str_list(value: str) -> list[str]:
|
| 63 |
+
return [item.strip() for item in value.split(",") if item.strip()]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def stable_case_id(parts: dict[str, Any]) -> str:
|
| 67 |
+
payload = json.dumps(parts, sort_keys=True, separators=(",", ":"))
|
| 68 |
+
return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def make_case_parts(
|
| 72 |
+
args: argparse.Namespace,
|
| 73 |
+
*,
|
| 74 |
+
architecture: str,
|
| 75 |
+
ssm_hidden_dim: int | None,
|
| 76 |
+
ssm_mixer_dim: int | None,
|
| 77 |
+
ssm_num_lanes: int | None,
|
| 78 |
+
hybrid_pattern: str | None,
|
| 79 |
+
ssm_gate_type: str | None,
|
| 80 |
+
learning_rate: float,
|
| 81 |
+
weight_decay: float,
|
| 82 |
+
batch_size: int,
|
| 83 |
+
) -> dict[str, Any]:
|
| 84 |
+
uses_ssm = architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 85 |
+
return {
|
| 86 |
+
"architecture": architecture,
|
| 87 |
+
"batch_size": batch_size,
|
| 88 |
+
"seq_len": args.seq_len,
|
| 89 |
+
"hidden_dim": args.hidden_dim,
|
| 90 |
+
"num_layers": args.num_layers,
|
| 91 |
+
"num_heads": args.num_heads,
|
| 92 |
+
"hidden_dim_ff": args.hidden_dim_ff,
|
| 93 |
+
"d_latent_kv": args.d_latent_kv,
|
| 94 |
+
"learning_rate": learning_rate,
|
| 95 |
+
"weight_decay": weight_decay,
|
| 96 |
+
"train_steps": args.train_steps,
|
| 97 |
+
"eval_batches": args.eval_batches,
|
| 98 |
+
"ssm_core": args.ssm_core if uses_ssm else None,
|
| 99 |
+
"ssm_hidden_dim": ssm_hidden_dim if uses_ssm else None,
|
| 100 |
+
"ssm_mixer_dim": ssm_mixer_dim if uses_ssm else None,
|
| 101 |
+
"ssm_num_lanes": ssm_num_lanes if uses_ssm else None,
|
| 102 |
+
"ssm_lane_mode": args.ssm_lane_mode if uses_ssm else None,
|
| 103 |
+
"ssm_split_mix": args.ssm_split_mix if uses_ssm else None,
|
| 104 |
+
"ssm_gate_type": ssm_gate_type if uses_ssm else None,
|
| 105 |
+
"ssm_branch_rms_norm": args.ssm_branch_rms_norm if uses_ssm else None,
|
| 106 |
+
"ssm_branch_clip_value": args.ssm_branch_clip_value if uses_ssm else None,
|
| 107 |
+
"block_residual_rms_norm": args.block_residual_rms_norm if uses_ssm else None,
|
| 108 |
+
"block_residual_rms_target": args.block_residual_rms_target if uses_ssm else None,
|
| 109 |
+
"block_residual_rms_cap": args.block_residual_rms_cap if uses_ssm else None,
|
| 110 |
+
"hybrid_pattern": hybrid_pattern if architecture == "taonet_hybrid" else None,
|
| 111 |
+
"hybrid_ssm_layers": args.hybrid_ssm_layers if architecture == "taonet_hybrid" else None,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def synchronize(device: torch.device) -> None:
|
| 116 |
+
if device.type == "cuda":
|
| 117 |
+
torch.cuda.synchronize(device)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def reset_memory(device: torch.device) -> None:
|
| 121 |
+
if device.type == "cuda":
|
| 122 |
+
torch.cuda.reset_peak_memory_stats(device)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def memory_stats(device: torch.device) -> dict[str, float | None]:
|
| 126 |
+
if device.type != "cuda":
|
| 127 |
+
return {"peak_allocated_mb": None, "peak_reserved_mb": None}
|
| 128 |
+
return {
|
| 129 |
+
"peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
|
| 130 |
+
"peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def iter_texts(path: Path, *, text_field: str, max_records: int | None) -> Iterable[str]:
|
| 135 |
+
suffix = path.suffix.lower()
|
| 136 |
+
count = 0
|
| 137 |
+
with path.open("r", encoding="utf-8", errors="replace") as handle:
|
| 138 |
+
if suffix in {".jsonl", ".json"}:
|
| 139 |
+
for line in handle:
|
| 140 |
+
if max_records is not None and count >= max_records:
|
| 141 |
+
break
|
| 142 |
+
line = line.strip()
|
| 143 |
+
if not line:
|
| 144 |
+
continue
|
| 145 |
+
try:
|
| 146 |
+
record = json.loads(line)
|
| 147 |
+
except json.JSONDecodeError:
|
| 148 |
+
continue
|
| 149 |
+
text = record.get(text_field)
|
| 150 |
+
if isinstance(text, str) and text:
|
| 151 |
+
count += 1
|
| 152 |
+
yield text
|
| 153 |
+
else:
|
| 154 |
+
for line in handle:
|
| 155 |
+
if max_records is not None and count >= max_records:
|
| 156 |
+
break
|
| 157 |
+
line = line.rstrip("\n")
|
| 158 |
+
if line:
|
| 159 |
+
count += 1
|
| 160 |
+
yield line
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def load_sentencepiece(path: Path):
|
| 164 |
+
import sentencepiece as spm
|
| 165 |
+
|
| 166 |
+
processor = spm.SentencePieceProcessor()
|
| 167 |
+
processor.load(str(path))
|
| 168 |
+
return processor
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def encode_text(text: str, *, tokenizer_type: str, sentencepiece_processor=None) -> list[int]:
|
| 172 |
+
if tokenizer_type == "byte":
|
| 173 |
+
return [byte + BYTE_OFFSET for byte in text.encode("utf-8", errors="replace")] + [BYTE_EOS_ID]
|
| 174 |
+
if tokenizer_type == "sentencepiece":
|
| 175 |
+
if sentencepiece_processor is None:
|
| 176 |
+
raise ValueError("sentencepiece tokenizer requested without a processor.")
|
| 177 |
+
ids = list(sentencepiece_processor.encode(text, out_type=int))
|
| 178 |
+
eos_id = sentencepiece_processor.eos_id()
|
| 179 |
+
if eos_id >= 0:
|
| 180 |
+
ids.append(eos_id)
|
| 181 |
+
return ids
|
| 182 |
+
raise ValueError(f"Unsupported tokenizer type '{tokenizer_type}'.")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def load_token_stream(args: argparse.Namespace) -> tuple[torch.Tensor, int]:
|
| 186 |
+
path = Path(args.data_path)
|
| 187 |
+
if not path.exists():
|
| 188 |
+
raise FileNotFoundError(f"Data path not found: {path}")
|
| 189 |
+
|
| 190 |
+
sp = None
|
| 191 |
+
if args.tokenizer_type == "sentencepiece":
|
| 192 |
+
if not args.tokenizer_path:
|
| 193 |
+
raise ValueError("--tokenizer-path is required for sentencepiece tokenization.")
|
| 194 |
+
sp = load_sentencepiece(Path(args.tokenizer_path))
|
| 195 |
+
vocab_size = int(sp.vocab_size())
|
| 196 |
+
else:
|
| 197 |
+
vocab_size = BYTE_VOCAB_SIZE
|
| 198 |
+
|
| 199 |
+
tokens: list[int] = []
|
| 200 |
+
for text in iter_texts(path, text_field=args.text_field, max_records=args.max_records):
|
| 201 |
+
tokens.extend(encode_text(text, tokenizer_type=args.tokenizer_type, sentencepiece_processor=sp))
|
| 202 |
+
if args.max_tokens is not None and len(tokens) >= args.max_tokens:
|
| 203 |
+
tokens = tokens[: args.max_tokens]
|
| 204 |
+
break
|
| 205 |
+
|
| 206 |
+
if len(tokens) < args.seq_len + 2:
|
| 207 |
+
raise ValueError(
|
| 208 |
+
f"Need at least {args.seq_len + 2} tokens, found {len(tokens)} from {path}."
|
| 209 |
+
)
|
| 210 |
+
return torch.tensor(tokens, dtype=torch.long), vocab_size
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def split_stream(tokens: torch.Tensor, eval_fraction: float) -> tuple[torch.Tensor, torch.Tensor]:
|
| 214 |
+
split = int(tokens.numel() * (1.0 - eval_fraction))
|
| 215 |
+
split = max(2, min(split, tokens.numel() - 2))
|
| 216 |
+
return tokens[:split].contiguous(), tokens[split:].contiguous()
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def sample_batch(
|
| 220 |
+
stream: torch.Tensor,
|
| 221 |
+
*,
|
| 222 |
+
batch_size: int,
|
| 223 |
+
seq_len: int,
|
| 224 |
+
device: torch.device,
|
| 225 |
+
generator: torch.Generator,
|
| 226 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 227 |
+
max_start = stream.numel() - seq_len - 1
|
| 228 |
+
if max_start <= 0:
|
| 229 |
+
raise ValueError(f"Token stream too short for seq_len={seq_len}.")
|
| 230 |
+
starts = torch.randint(0, max_start, (batch_size,), generator=generator)
|
| 231 |
+
rows = [stream[int(start) : int(start) + seq_len + 1] for start in starts]
|
| 232 |
+
batch = torch.stack(rows, dim=0).to(device=device)
|
| 233 |
+
input_ids = batch[:, :-1].contiguous()
|
| 234 |
+
labels = batch[:, 1:].contiguous()
|
| 235 |
+
attention_mask = torch.ones_like(input_ids)
|
| 236 |
+
return input_ids, labels, attention_mask
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
|
| 240 |
+
predictions = torch.argmax(logits, dim=-1)
|
| 241 |
+
correct = predictions == labels
|
| 242 |
+
return float(correct.sum().detach().cpu() / correct.numel())
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def build_config(
|
| 246 |
+
args: argparse.Namespace,
|
| 247 |
+
*,
|
| 248 |
+
architecture: str,
|
| 249 |
+
vocab_size: int,
|
| 250 |
+
ssm_hidden_dim: int | None,
|
| 251 |
+
ssm_mixer_dim: int | None,
|
| 252 |
+
ssm_num_lanes: int | None,
|
| 253 |
+
hybrid_pattern: str | None,
|
| 254 |
+
ssm_gate_type: str | None,
|
| 255 |
+
) -> ModelConfig:
|
| 256 |
+
uses_ssm = architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 257 |
+
d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75)
|
| 258 |
+
d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads
|
| 259 |
+
hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4
|
| 260 |
+
return ModelConfig(
|
| 261 |
+
architecture_type=architecture,
|
| 262 |
+
vocab_size=vocab_size,
|
| 263 |
+
hidden_dim=args.hidden_dim,
|
| 264 |
+
num_layers=args.num_layers,
|
| 265 |
+
num_heads=args.num_heads,
|
| 266 |
+
max_seq_length=args.seq_len,
|
| 267 |
+
d_latent_kv=d_latent_kv,
|
| 268 |
+
d_rope=d_rope,
|
| 269 |
+
hidden_dim_ff=hidden_dim_ff,
|
| 270 |
+
dropout=args.dropout,
|
| 271 |
+
gqa_groups=args.gqa_groups,
|
| 272 |
+
rope_scale=args.rope_scale,
|
| 273 |
+
yarn_alpha=args.yarn_alpha,
|
| 274 |
+
init_std=args.init_std,
|
| 275 |
+
ssm_core=args.ssm_core,
|
| 276 |
+
ssm_hidden_dim=ssm_hidden_dim,
|
| 277 |
+
ssm_mixer_dim=ssm_mixer_dim,
|
| 278 |
+
ssm_num_lanes=ssm_num_lanes or args.ssm_num_lanes,
|
| 279 |
+
ssm_lane_combine=args.ssm_lane_combine,
|
| 280 |
+
ssm_lane_mode=args.ssm_lane_mode,
|
| 281 |
+
ssm_split_mix=args.ssm_split_mix,
|
| 282 |
+
ssm_rank=args.ssm_rank,
|
| 283 |
+
ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
|
| 284 |
+
ssm_finite_tail_correction=args.ssm_finite_tail_correction,
|
| 285 |
+
ssm_kernel_mode=args.ssm_kernel_mode,
|
| 286 |
+
ssm_kernel_threshold=args.ssm_kernel_threshold,
|
| 287 |
+
ssm_dt_min=args.ssm_dt_min,
|
| 288 |
+
ssm_dt_max=args.ssm_dt_max,
|
| 289 |
+
ssm_dt_init=args.ssm_dt_init,
|
| 290 |
+
ssm_use_padding_mask=False,
|
| 291 |
+
ssm_activation=args.ssm_activation,
|
| 292 |
+
ssm_gate=args.ssm_gate,
|
| 293 |
+
ssm_input_gate=args.ssm_input_gate,
|
| 294 |
+
ssm_gate_type=ssm_gate_type or args.ssm_gate_type,
|
| 295 |
+
ssm_layer_scale_init=args.ssm_layer_scale_init,
|
| 296 |
+
ssm_branch_rms_norm=args.ssm_branch_rms_norm if uses_ssm else False,
|
| 297 |
+
ssm_branch_clip_value=args.ssm_branch_clip_value if uses_ssm else None,
|
| 298 |
+
block_residual_rms_norm=args.block_residual_rms_norm if uses_ssm else False,
|
| 299 |
+
block_residual_rms_target=args.block_residual_rms_target,
|
| 300 |
+
block_residual_rms_cap=args.block_residual_rms_cap if uses_ssm else None,
|
| 301 |
+
ssm_local_shift=args.ssm_local_shift,
|
| 302 |
+
ssm_local_shift_init=args.ssm_local_shift_init,
|
| 303 |
+
ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
|
| 304 |
+
hybrid_pattern=hybrid_pattern or args.hybrid_pattern,
|
| 305 |
+
hybrid_ssm_layers=args.hybrid_ssm_layers or None,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]:
|
| 310 |
+
last_loss = float("nan")
|
| 311 |
+
for _ in range(warmup):
|
| 312 |
+
last_loss = fn()
|
| 313 |
+
synchronize(device)
|
| 314 |
+
|
| 315 |
+
latencies = []
|
| 316 |
+
for _ in range(repeats):
|
| 317 |
+
reset_memory(device)
|
| 318 |
+
synchronize(device)
|
| 319 |
+
start = time.perf_counter()
|
| 320 |
+
last_loss = fn()
|
| 321 |
+
synchronize(device)
|
| 322 |
+
latencies.append(time.perf_counter() - start)
|
| 323 |
+
return sum(latencies) / len(latencies), min(latencies), last_loss
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def evaluate_model(
|
| 327 |
+
model: torch.nn.Module,
|
| 328 |
+
*,
|
| 329 |
+
eval_stream: torch.Tensor,
|
| 330 |
+
args: argparse.Namespace,
|
| 331 |
+
batch_size: int,
|
| 332 |
+
device: torch.device,
|
| 333 |
+
generator: torch.Generator,
|
| 334 |
+
autocast_context,
|
| 335 |
+
) -> tuple[float, float]:
|
| 336 |
+
model.eval()
|
| 337 |
+
losses = []
|
| 338 |
+
accuracies = []
|
| 339 |
+
with torch.no_grad():
|
| 340 |
+
for _ in range(args.eval_batches):
|
| 341 |
+
input_ids, labels, attention_mask = sample_batch(
|
| 342 |
+
eval_stream,
|
| 343 |
+
batch_size=batch_size,
|
| 344 |
+
seq_len=args.seq_len,
|
| 345 |
+
device=device,
|
| 346 |
+
generator=generator,
|
| 347 |
+
)
|
| 348 |
+
with autocast_context():
|
| 349 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 350 |
+
losses.append(float(outputs["loss"].detach().cpu()))
|
| 351 |
+
accuracies.append(token_accuracy(outputs["logits"], labels))
|
| 352 |
+
model.train()
|
| 353 |
+
return sum(losses) / len(losses), sum(accuracies) / len(accuracies)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def train_model(
|
| 357 |
+
model: torch.nn.Module,
|
| 358 |
+
*,
|
| 359 |
+
train_stream: torch.Tensor,
|
| 360 |
+
args: argparse.Namespace,
|
| 361 |
+
learning_rate: float,
|
| 362 |
+
weight_decay: float,
|
| 363 |
+
batch_size: int,
|
| 364 |
+
device: torch.device,
|
| 365 |
+
generator: torch.Generator,
|
| 366 |
+
autocast_context,
|
| 367 |
+
) -> tuple[float | None, float | None, float | None, float | None, list[dict[str, float | int]]]:
|
| 368 |
+
if args.train_steps <= 0:
|
| 369 |
+
return None, None, None, None, []
|
| 370 |
+
|
| 371 |
+
model.train()
|
| 372 |
+
optimizer = torch.optim.AdamW(
|
| 373 |
+
model.parameters(),
|
| 374 |
+
lr=learning_rate,
|
| 375 |
+
weight_decay=weight_decay,
|
| 376 |
+
)
|
| 377 |
+
last_loss = float("nan")
|
| 378 |
+
last_grad_norm: float | None = None
|
| 379 |
+
max_grad_norm_seen: float | None = None
|
| 380 |
+
history: list[dict[str, float | int]] = []
|
| 381 |
+
start = time.perf_counter()
|
| 382 |
+
trainable_params = [param for param in model.parameters() if param.requires_grad]
|
| 383 |
+
log_every = max(0, int(args.train_log_every))
|
| 384 |
+
for step in range(1, args.train_steps + 1):
|
| 385 |
+
input_ids, labels, attention_mask = sample_batch(
|
| 386 |
+
train_stream,
|
| 387 |
+
batch_size=batch_size,
|
| 388 |
+
seq_len=args.seq_len,
|
| 389 |
+
device=device,
|
| 390 |
+
generator=generator,
|
| 391 |
+
)
|
| 392 |
+
optimizer.zero_grad(set_to_none=True)
|
| 393 |
+
with autocast_context():
|
| 394 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 395 |
+
loss = outputs["loss"]
|
| 396 |
+
loss.backward()
|
| 397 |
+
if args.max_grad_norm and args.max_grad_norm > 0:
|
| 398 |
+
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
|
| 399 |
+
last_grad_norm = float(grad_norm_tensor.detach().cpu())
|
| 400 |
+
max_grad_norm_seen = (
|
| 401 |
+
last_grad_norm
|
| 402 |
+
if max_grad_norm_seen is None
|
| 403 |
+
else max(max_grad_norm_seen, last_grad_norm)
|
| 404 |
+
)
|
| 405 |
+
optimizer.step()
|
| 406 |
+
last_loss = float(loss.detach().cpu())
|
| 407 |
+
if log_every and (step == 1 or step % log_every == 0 or step == args.train_steps):
|
| 408 |
+
history.append(
|
| 409 |
+
{
|
| 410 |
+
"step": step,
|
| 411 |
+
"loss": last_loss,
|
| 412 |
+
"grad_norm": last_grad_norm if last_grad_norm is not None else float("nan"),
|
| 413 |
+
}
|
| 414 |
+
)
|
| 415 |
+
synchronize(device)
|
| 416 |
+
return last_loss, time.perf_counter() - start, last_grad_norm, max_grad_norm_seen, history
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def save_case_checkpoint(
|
| 420 |
+
*,
|
| 421 |
+
args: argparse.Namespace,
|
| 422 |
+
model: torch.nn.Module,
|
| 423 |
+
config: ModelConfig,
|
| 424 |
+
case_id: str,
|
| 425 |
+
case_parts: dict[str, Any],
|
| 426 |
+
train_final_loss: float | None,
|
| 427 |
+
train_seconds: float | None,
|
| 428 |
+
train_history: list[dict[str, float | int]],
|
| 429 |
+
eval_loss: float,
|
| 430 |
+
eval_accuracy: float,
|
| 431 |
+
) -> str | None:
|
| 432 |
+
if not args.save_case_checkpoints:
|
| 433 |
+
return None
|
| 434 |
+
checkpoint_dir = Path(args.checkpoint_dir) if args.checkpoint_dir else Path(args.output_dir) / "checkpoints"
|
| 435 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 436 |
+
path = checkpoint_dir / f"{case_id}.pt"
|
| 437 |
+
checkpoint = {
|
| 438 |
+
"step": args.train_steps,
|
| 439 |
+
"model_state": model.state_dict(),
|
| 440 |
+
"optimizer_state": None,
|
| 441 |
+
"config": {"model": config.to_dict(), "benchmark_args": vars(args)},
|
| 442 |
+
"metrics": {
|
| 443 |
+
"train_final_loss": train_final_loss,
|
| 444 |
+
"train_seconds": train_seconds,
|
| 445 |
+
"train_history": train_history,
|
| 446 |
+
"eval_loss": eval_loss,
|
| 447 |
+
"eval_accuracy": eval_accuracy,
|
| 448 |
+
},
|
| 449 |
+
"case_id": case_id,
|
| 450 |
+
"case_parts": case_parts,
|
| 451 |
+
}
|
| 452 |
+
torch.save(checkpoint, path)
|
| 453 |
+
latest_path = checkpoint_dir / "latest.pt"
|
| 454 |
+
try:
|
| 455 |
+
torch.save(checkpoint, latest_path)
|
| 456 |
+
except Exception:
|
| 457 |
+
pass
|
| 458 |
+
return str(path)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def benchmark_case(
|
| 462 |
+
*,
|
| 463 |
+
args: argparse.Namespace,
|
| 464 |
+
architecture: str,
|
| 465 |
+
ssm_hidden_dim: int | None,
|
| 466 |
+
ssm_mixer_dim: int | None,
|
| 467 |
+
ssm_num_lanes: int | None,
|
| 468 |
+
hybrid_pattern: str | None,
|
| 469 |
+
ssm_gate_type: str | None,
|
| 470 |
+
learning_rate: float,
|
| 471 |
+
weight_decay: float,
|
| 472 |
+
vocab_size: int,
|
| 473 |
+
train_stream: torch.Tensor,
|
| 474 |
+
eval_stream: torch.Tensor,
|
| 475 |
+
batch_size: int,
|
| 476 |
+
dtype: torch.dtype,
|
| 477 |
+
device: torch.device,
|
| 478 |
+
) -> list[dict[str, Any]]:
|
| 479 |
+
seed_offset = 0 if architecture == "taonet" else int(ssm_hidden_dim or 0) + int(ssm_mixer_dim or 0)
|
| 480 |
+
seed_offset += int(ssm_num_lanes or 0) * 31
|
| 481 |
+
if hybrid_pattern:
|
| 482 |
+
seed_offset += sum(ord(char) for char in hybrid_pattern)
|
| 483 |
+
if ssm_gate_type:
|
| 484 |
+
seed_offset += sum(ord(char) for char in ssm_gate_type)
|
| 485 |
+
case_seed = args.seed + seed_offset + batch_size
|
| 486 |
+
torch.manual_seed(case_seed)
|
| 487 |
+
if device.type == "cuda":
|
| 488 |
+
torch.cuda.manual_seed_all(case_seed)
|
| 489 |
+
cpu_generator = torch.Generator().manual_seed(case_seed)
|
| 490 |
+
case_parts = make_case_parts(
|
| 491 |
+
args,
|
| 492 |
+
architecture=architecture,
|
| 493 |
+
ssm_hidden_dim=ssm_hidden_dim,
|
| 494 |
+
ssm_mixer_dim=ssm_mixer_dim,
|
| 495 |
+
ssm_num_lanes=ssm_num_lanes,
|
| 496 |
+
hybrid_pattern=hybrid_pattern,
|
| 497 |
+
ssm_gate_type=ssm_gate_type,
|
| 498 |
+
learning_rate=learning_rate,
|
| 499 |
+
weight_decay=weight_decay,
|
| 500 |
+
batch_size=batch_size,
|
| 501 |
+
)
|
| 502 |
+
case_id = stable_case_id(case_parts)
|
| 503 |
+
config = build_config(
|
| 504 |
+
args,
|
| 505 |
+
architecture=architecture,
|
| 506 |
+
vocab_size=vocab_size,
|
| 507 |
+
ssm_hidden_dim=ssm_hidden_dim,
|
| 508 |
+
ssm_mixer_dim=ssm_mixer_dim,
|
| 509 |
+
ssm_num_lanes=ssm_num_lanes,
|
| 510 |
+
hybrid_pattern=hybrid_pattern,
|
| 511 |
+
ssm_gate_type=ssm_gate_type,
|
| 512 |
+
)
|
| 513 |
+
with redirect_stdout(io.StringIO()):
|
| 514 |
+
model = get_model(config, device=device)
|
| 515 |
+
model.train()
|
| 516 |
+
|
| 517 |
+
if device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}:
|
| 518 |
+
def autocast_context():
|
| 519 |
+
return torch.autocast(device_type="cuda", dtype=dtype, enabled=True)
|
| 520 |
+
else:
|
| 521 |
+
def autocast_context():
|
| 522 |
+
return nullcontext()
|
| 523 |
+
|
| 524 |
+
(
|
| 525 |
+
train_final_loss,
|
| 526 |
+
train_seconds,
|
| 527 |
+
train_last_grad_norm,
|
| 528 |
+
train_max_grad_norm,
|
| 529 |
+
train_history,
|
| 530 |
+
) = train_model(
|
| 531 |
+
model,
|
| 532 |
+
train_stream=train_stream,
|
| 533 |
+
args=args,
|
| 534 |
+
learning_rate=learning_rate,
|
| 535 |
+
weight_decay=weight_decay,
|
| 536 |
+
batch_size=batch_size,
|
| 537 |
+
device=device,
|
| 538 |
+
generator=cpu_generator,
|
| 539 |
+
autocast_context=autocast_context,
|
| 540 |
+
)
|
| 541 |
+
eval_loss, eval_accuracy = evaluate_model(
|
| 542 |
+
model,
|
| 543 |
+
eval_stream=eval_stream,
|
| 544 |
+
args=args,
|
| 545 |
+
batch_size=batch_size,
|
| 546 |
+
device=device,
|
| 547 |
+
generator=cpu_generator,
|
| 548 |
+
autocast_context=autocast_context,
|
| 549 |
+
)
|
| 550 |
+
checkpoint_path = save_case_checkpoint(
|
| 551 |
+
args=args,
|
| 552 |
+
model=model,
|
| 553 |
+
config=config,
|
| 554 |
+
case_id=case_id,
|
| 555 |
+
case_parts=case_parts,
|
| 556 |
+
train_final_loss=train_final_loss,
|
| 557 |
+
train_seconds=train_seconds,
|
| 558 |
+
train_history=train_history,
|
| 559 |
+
eval_loss=eval_loss,
|
| 560 |
+
eval_accuracy=eval_accuracy,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
input_ids, labels, attention_mask = sample_batch(
|
| 564 |
+
eval_stream,
|
| 565 |
+
batch_size=batch_size,
|
| 566 |
+
seq_len=args.seq_len,
|
| 567 |
+
device=device,
|
| 568 |
+
generator=cpu_generator,
|
| 569 |
+
)
|
| 570 |
+
tokens = batch_size * args.seq_len
|
| 571 |
+
total_params = sum(param.numel() for param in model.parameters())
|
| 572 |
+
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
|
| 573 |
+
rows: list[dict[str, Any]] = []
|
| 574 |
+
uses_ssm = architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 575 |
+
|
| 576 |
+
def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None:
|
| 577 |
+
perplexity = math.exp(min(eval_loss, 20.0))
|
| 578 |
+
rows.append(
|
| 579 |
+
{
|
| 580 |
+
"architecture": architecture,
|
| 581 |
+
"case_id": case_id,
|
| 582 |
+
"checkpoint_path": checkpoint_path,
|
| 583 |
+
"ssm_core": args.ssm_core if uses_ssm else None,
|
| 584 |
+
"ssm_hidden_dim": ssm_hidden_dim if uses_ssm else None,
|
| 585 |
+
"ssm_mixer_dim": ssm_mixer_dim if uses_ssm else None,
|
| 586 |
+
"ssm_num_lanes": ssm_num_lanes if uses_ssm else None,
|
| 587 |
+
"ssm_lane_combine": args.ssm_lane_combine if uses_ssm else None,
|
| 588 |
+
"ssm_lane_mode": args.ssm_lane_mode if uses_ssm else None,
|
| 589 |
+
"ssm_split_mix": args.ssm_split_mix if uses_ssm else None,
|
| 590 |
+
"ssm_finite_tail_correction": args.ssm_finite_tail_correction if uses_ssm else None,
|
| 591 |
+
"ssm_gate_type": ssm_gate_type if uses_ssm else None,
|
| 592 |
+
"ssm_local_shift": args.ssm_local_shift if uses_ssm else None,
|
| 593 |
+
"ssm_local_shift_per_channel": args.ssm_local_shift_per_channel if uses_ssm else None,
|
| 594 |
+
"ssm_branch_rms_norm": args.ssm_branch_rms_norm if uses_ssm else None,
|
| 595 |
+
"ssm_branch_clip_value": args.ssm_branch_clip_value if uses_ssm else None,
|
| 596 |
+
"block_residual_rms_norm": args.block_residual_rms_norm if uses_ssm else None,
|
| 597 |
+
"block_residual_rms_target": args.block_residual_rms_target if uses_ssm else None,
|
| 598 |
+
"block_residual_rms_cap": args.block_residual_rms_cap if uses_ssm else None,
|
| 599 |
+
"hybrid_pattern": hybrid_pattern if architecture == "taonet_hybrid" else None,
|
| 600 |
+
"hybrid_ssm_layers": args.hybrid_ssm_layers if architecture == "taonet_hybrid" else None,
|
| 601 |
+
"mode": mode,
|
| 602 |
+
"batch_size": batch_size,
|
| 603 |
+
"seq_len": args.seq_len,
|
| 604 |
+
"tokens": tokens,
|
| 605 |
+
"vocab_size": vocab_size,
|
| 606 |
+
"tokenizer_type": args.tokenizer_type,
|
| 607 |
+
"hidden_dim": args.hidden_dim,
|
| 608 |
+
"num_layers": args.num_layers,
|
| 609 |
+
"num_heads": args.num_heads,
|
| 610 |
+
"dtype": str(dtype).replace("torch.", ""),
|
| 611 |
+
"device": str(device),
|
| 612 |
+
"learning_rate": learning_rate,
|
| 613 |
+
"weight_decay": weight_decay,
|
| 614 |
+
"total_params": total_params,
|
| 615 |
+
"trainable_params": trainable_params,
|
| 616 |
+
"mean_ms": mean_s * 1000.0,
|
| 617 |
+
"min_ms": min_s * 1000.0,
|
| 618 |
+
"tokens_per_s_mean": tokens / max(mean_s, 1e-12),
|
| 619 |
+
"tokens_per_s_best": tokens / max(min_s, 1e-12),
|
| 620 |
+
"loss": loss,
|
| 621 |
+
"eval_loss": eval_loss,
|
| 622 |
+
"eval_perplexity": perplexity,
|
| 623 |
+
"eval_accuracy": eval_accuracy,
|
| 624 |
+
"train_final_loss": train_final_loss,
|
| 625 |
+
"train_seconds": train_seconds,
|
| 626 |
+
"train_history": json.dumps(train_history, separators=(",", ":")),
|
| 627 |
+
"train_last_grad_norm": train_last_grad_norm,
|
| 628 |
+
"train_max_grad_norm": train_max_grad_norm,
|
| 629 |
+
"max_grad_norm": args.max_grad_norm,
|
| 630 |
+
**memory_stats(device),
|
| 631 |
+
}
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
def forward_only() -> float:
|
| 635 |
+
with torch.no_grad():
|
| 636 |
+
with autocast_context():
|
| 637 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 638 |
+
return float(outputs["loss"].detach().cpu())
|
| 639 |
+
|
| 640 |
+
mean_s, min_s, loss = time_repeats(
|
| 641 |
+
forward_only,
|
| 642 |
+
device=device,
|
| 643 |
+
warmup=args.warmup,
|
| 644 |
+
repeats=args.repeats,
|
| 645 |
+
)
|
| 646 |
+
add_row("forward", mean_s, min_s, loss)
|
| 647 |
+
|
| 648 |
+
if args.backward:
|
| 649 |
+
def forward_backward() -> float:
|
| 650 |
+
model.zero_grad(set_to_none=True)
|
| 651 |
+
with autocast_context():
|
| 652 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 653 |
+
loss = outputs["loss"]
|
| 654 |
+
loss.backward()
|
| 655 |
+
return float(loss.detach().cpu())
|
| 656 |
+
|
| 657 |
+
mean_s, min_s, loss = time_repeats(
|
| 658 |
+
forward_backward,
|
| 659 |
+
device=device,
|
| 660 |
+
warmup=args.warmup,
|
| 661 |
+
repeats=args.repeats,
|
| 662 |
+
)
|
| 663 |
+
add_row("forward_backward", mean_s, min_s, loss)
|
| 664 |
+
|
| 665 |
+
return rows
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def print_table(rows: list[dict[str, Any]]) -> None:
|
| 669 |
+
columns = [
|
| 670 |
+
"architecture",
|
| 671 |
+
"hybrid_pattern",
|
| 672 |
+
"ssm_hidden_dim",
|
| 673 |
+
"ssm_mixer_dim",
|
| 674 |
+
"ssm_num_lanes",
|
| 675 |
+
"ssm_lane_combine",
|
| 676 |
+
"ssm_lane_mode",
|
| 677 |
+
"ssm_split_mix",
|
| 678 |
+
"ssm_gate_type",
|
| 679 |
+
"ssm_branch_rms_norm",
|
| 680 |
+
"block_residual_rms_norm",
|
| 681 |
+
"block_residual_rms_cap",
|
| 682 |
+
"learning_rate",
|
| 683 |
+
"mode",
|
| 684 |
+
"batch_size",
|
| 685 |
+
"seq_len",
|
| 686 |
+
"mean_ms",
|
| 687 |
+
"tokens_per_s_mean",
|
| 688 |
+
"peak_allocated_mb",
|
| 689 |
+
"eval_loss",
|
| 690 |
+
"eval_perplexity",
|
| 691 |
+
"eval_accuracy",
|
| 692 |
+
"train_max_grad_norm",
|
| 693 |
+
]
|
| 694 |
+
print("\t".join(columns))
|
| 695 |
+
for row in rows:
|
| 696 |
+
values = []
|
| 697 |
+
for column in columns:
|
| 698 |
+
value = row[column]
|
| 699 |
+
if isinstance(value, float):
|
| 700 |
+
values.append(f"{value:.3f}")
|
| 701 |
+
else:
|
| 702 |
+
values.append(str(value))
|
| 703 |
+
print("\t".join(values))
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
|
| 707 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 708 |
+
json_path = output_dir / "taonet_real_token_benchmark.json"
|
| 709 |
+
csv_path = output_dir / "taonet_real_token_benchmark.csv"
|
| 710 |
+
json_path.write_text(json.dumps({"metadata": metadata, "results": rows}, indent=2), encoding="utf-8")
|
| 711 |
+
|
| 712 |
+
fieldnames = list(rows[0].keys()) if rows else []
|
| 713 |
+
with csv_path.open("w", newline="", encoding="utf-8") as handle:
|
| 714 |
+
writer = csv.DictWriter(handle, fieldnames=fieldnames)
|
| 715 |
+
writer.writeheader()
|
| 716 |
+
writer.writerows(rows)
|
| 717 |
+
|
| 718 |
+
print(f"Wrote {json_path}")
|
| 719 |
+
print(f"Wrote {csv_path}")
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def load_existing_rows(output_dir: Path) -> list[dict[str, Any]]:
|
| 723 |
+
csv_path = output_dir / "taonet_real_token_benchmark.csv"
|
| 724 |
+
if not csv_path.exists():
|
| 725 |
+
return []
|
| 726 |
+
with csv_path.open("r", newline="", encoding="utf-8") as handle:
|
| 727 |
+
return list(csv.DictReader(handle))
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
def main() -> None:
|
| 731 |
+
parser = argparse.ArgumentParser(description="Benchmark TaoNet variants on a real text token stream.")
|
| 732 |
+
parser.add_argument("--data-path", required=True)
|
| 733 |
+
parser.add_argument("--text-field", default="text")
|
| 734 |
+
parser.add_argument("--tokenizer-type", choices=["byte", "sentencepiece"], default="byte")
|
| 735 |
+
parser.add_argument("--tokenizer-path", default="")
|
| 736 |
+
parser.add_argument("--max-records", type=int, default=None)
|
| 737 |
+
parser.add_argument("--max-tokens", type=int, default=1_000_000)
|
| 738 |
+
parser.add_argument("--eval-fraction", type=float, default=0.1)
|
| 739 |
+
parser.add_argument("--architectures", default="taonet,taonet_ssm")
|
| 740 |
+
parser.add_argument("--batch-sizes", default="8,16")
|
| 741 |
+
parser.add_argument("--seq-len", type=int, default=512)
|
| 742 |
+
parser.add_argument("--hidden-dim", type=int, default=256)
|
| 743 |
+
parser.add_argument("--num-layers", type=int, default=4)
|
| 744 |
+
parser.add_argument("--num-heads", type=int, default=4)
|
| 745 |
+
parser.add_argument("--d-latent-kv", type=int, default=None)
|
| 746 |
+
parser.add_argument("--d-rope", type=int, default=None)
|
| 747 |
+
parser.add_argument("--hidden-dim-ff", type=int, default=None)
|
| 748 |
+
parser.add_argument("--dropout", type=float, default=0.0)
|
| 749 |
+
parser.add_argument("--gqa-groups", type=int, default=1)
|
| 750 |
+
parser.add_argument("--rope-scale", type=float, default=40.0)
|
| 751 |
+
parser.add_argument("--yarn-alpha", type=float, default=1.0)
|
| 752 |
+
parser.add_argument("--init-std", type=float, default=0.02)
|
| 753 |
+
parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
|
| 754 |
+
parser.add_argument("--ssm-hidden-dims", default="16,64")
|
| 755 |
+
parser.add_argument("--ssm-mixer-dim", type=int, default=64)
|
| 756 |
+
parser.add_argument(
|
| 757 |
+
"--ssm-mixer-dims",
|
| 758 |
+
default="",
|
| 759 |
+
help="Comma-separated SSM mixer projection dimensions to sweep. Defaults to --ssm-mixer-dim.",
|
| 760 |
+
)
|
| 761 |
+
parser.add_argument("--ssm-num-lanes", type=int, default=1)
|
| 762 |
+
parser.add_argument(
|
| 763 |
+
"--ssm-num-lanes-list",
|
| 764 |
+
default="",
|
| 765 |
+
help="Comma-separated SSM lane counts to sweep for SSM-bearing architectures.",
|
| 766 |
+
)
|
| 767 |
+
parser.add_argument("--ssm-lane-combine", choices=["mean", "channel"], default="mean")
|
| 768 |
+
parser.add_argument("--ssm-lane-mode", choices=["full", "split"], default="full")
|
| 769 |
+
parser.add_argument(
|
| 770 |
+
"--ssm-lane-modes",
|
| 771 |
+
default="",
|
| 772 |
+
help="Comma-separated SSM lane modes to sweep for SSM-bearing architectures.",
|
| 773 |
+
)
|
| 774 |
+
parser.add_argument("--ssm-split-mix", choices=["none", "hadamard"], default="none")
|
| 775 |
+
parser.add_argument(
|
| 776 |
+
"--ssm-split-mixes",
|
| 777 |
+
default="",
|
| 778 |
+
help="Comma-separated split-lane cross mixers to sweep for SSM-bearing architectures.",
|
| 779 |
+
)
|
| 780 |
+
parser.add_argument("--ssm-rank", type=int, default=1)
|
| 781 |
+
parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
|
| 782 |
+
parser.add_argument("--ssm-finite-tail-correction", action=argparse.BooleanOptionalAction, default=True)
|
| 783 |
+
parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
|
| 784 |
+
parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
|
| 785 |
+
parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
|
| 786 |
+
parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
|
| 787 |
+
parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
|
| 788 |
+
parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
|
| 789 |
+
parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
|
| 790 |
+
parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
|
| 791 |
+
parser.add_argument("--ssm-gate-type", choices=["dense", "channel"], default="dense")
|
| 792 |
+
parser.add_argument(
|
| 793 |
+
"--ssm-gate-types",
|
| 794 |
+
default="",
|
| 795 |
+
help="Comma-separated SSM gate implementations to sweep for SSM-bearing architectures.",
|
| 796 |
+
)
|
| 797 |
+
parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
|
| 798 |
+
parser.add_argument("--ssm-branch-rms-norm", action=argparse.BooleanOptionalAction, default=False)
|
| 799 |
+
parser.add_argument("--ssm-branch-clip-value", type=float, default=None)
|
| 800 |
+
parser.add_argument("--block-residual-rms-norm", action=argparse.BooleanOptionalAction, default=False)
|
| 801 |
+
parser.add_argument("--block-residual-rms-target", type=float, default=1.0)
|
| 802 |
+
parser.add_argument("--block-residual-rms-cap", type=float, default=None)
|
| 803 |
+
parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True)
|
| 804 |
+
parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
|
| 805 |
+
parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=False)
|
| 806 |
+
parser.add_argument(
|
| 807 |
+
"--hybrid-pattern",
|
| 808 |
+
choices=["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"],
|
| 809 |
+
default="attention_first",
|
| 810 |
+
)
|
| 811 |
+
parser.add_argument(
|
| 812 |
+
"--hybrid-patterns",
|
| 813 |
+
default="",
|
| 814 |
+
help="Comma-separated taonet_hybrid patterns to sweep. Ignored when --hybrid-ssm-layers is set.",
|
| 815 |
+
)
|
| 816 |
+
parser.add_argument(
|
| 817 |
+
"--hybrid-ssm-layers",
|
| 818 |
+
default="",
|
| 819 |
+
help="Comma-separated 0-based SSM layer indices for taonet_hybrid. Overrides pattern selection.",
|
| 820 |
+
)
|
| 821 |
+
parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
|
| 822 |
+
parser.add_argument("--device", default="auto")
|
| 823 |
+
parser.add_argument("--warmup", type=int, default=1)
|
| 824 |
+
parser.add_argument("--repeats", type=int, default=3)
|
| 825 |
+
parser.add_argument("--backward", action="store_true")
|
| 826 |
+
parser.add_argument("--train-steps", type=int, default=100)
|
| 827 |
+
parser.add_argument(
|
| 828 |
+
"--train-log-every",
|
| 829 |
+
type=int,
|
| 830 |
+
default=0,
|
| 831 |
+
help="Record train loss/grad history every N steps in CSV/JSON/checkpoints. 0 disables.",
|
| 832 |
+
)
|
| 833 |
+
parser.add_argument("--learning-rate", type=float, default=8e-4)
|
| 834 |
+
parser.add_argument(
|
| 835 |
+
"--max-grad-norm",
|
| 836 |
+
type=float,
|
| 837 |
+
default=0.0,
|
| 838 |
+
help="Clip training gradients to this norm during the benchmark. Use <=0 to disable.",
|
| 839 |
+
)
|
| 840 |
+
parser.add_argument(
|
| 841 |
+
"--learning-rates",
|
| 842 |
+
default="",
|
| 843 |
+
help="Comma-separated SSM learning rates to sweep. Attention uses --learning-rate once.",
|
| 844 |
+
)
|
| 845 |
+
parser.add_argument("--weight-decay", type=float, default=0.01)
|
| 846 |
+
parser.add_argument(
|
| 847 |
+
"--weight-decays",
|
| 848 |
+
default="",
|
| 849 |
+
help="Comma-separated SSM weight decays to sweep. Attention uses --weight-decay once.",
|
| 850 |
+
)
|
| 851 |
+
parser.add_argument("--eval-batches", type=int, default=8)
|
| 852 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 853 |
+
parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/real-token-bench"))
|
| 854 |
+
parser.add_argument(
|
| 855 |
+
"--resume-completed",
|
| 856 |
+
action="store_true",
|
| 857 |
+
help="Skip benchmark cases whose case_id is already present in the output CSV.",
|
| 858 |
+
)
|
| 859 |
+
parser.add_argument(
|
| 860 |
+
"--incremental-output",
|
| 861 |
+
action=argparse.BooleanOptionalAction,
|
| 862 |
+
default=True,
|
| 863 |
+
help="Rewrite CSV/JSON outputs after each completed case so detached jobs leave usable partial results.",
|
| 864 |
+
)
|
| 865 |
+
parser.add_argument(
|
| 866 |
+
"--save-case-checkpoints",
|
| 867 |
+
action="store_true",
|
| 868 |
+
help="Save a model checkpoint after each trained benchmark case.",
|
| 869 |
+
)
|
| 870 |
+
parser.add_argument(
|
| 871 |
+
"--checkpoint-dir",
|
| 872 |
+
default="",
|
| 873 |
+
help="Directory for per-case checkpoints. Defaults to <output-dir>/checkpoints.",
|
| 874 |
+
)
|
| 875 |
+
args = parser.parse_args()
|
| 876 |
+
|
| 877 |
+
random.seed(args.seed)
|
| 878 |
+
torch.manual_seed(args.seed)
|
| 879 |
+
if args.device == "auto":
|
| 880 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 881 |
+
else:
|
| 882 |
+
device = torch.device(args.device)
|
| 883 |
+
dtype = DTYPES[args.dtype]
|
| 884 |
+
if device.type == "cuda":
|
| 885 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 886 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 887 |
+
|
| 888 |
+
tokens, vocab_size = load_token_stream(args)
|
| 889 |
+
train_stream, eval_stream = split_stream(tokens, args.eval_fraction)
|
| 890 |
+
architectures = [item.strip() for item in args.architectures.split(",") if item.strip()]
|
| 891 |
+
ssm_hidden_dims = parse_int_list(args.ssm_hidden_dims)
|
| 892 |
+
ssm_mixer_dims = parse_int_list(args.ssm_mixer_dims) if args.ssm_mixer_dims else [args.ssm_mixer_dim]
|
| 893 |
+
ssm_num_lanes_options = parse_int_list(args.ssm_num_lanes_list) if args.ssm_num_lanes_list else [args.ssm_num_lanes]
|
| 894 |
+
ssm_lane_mode_options = parse_str_list(args.ssm_lane_modes) if args.ssm_lane_modes else [args.ssm_lane_mode]
|
| 895 |
+
ssm_split_mix_options = parse_str_list(args.ssm_split_mixes) if args.ssm_split_mixes else [args.ssm_split_mix]
|
| 896 |
+
ssm_learning_rates = parse_float_list(args.learning_rates) if args.learning_rates else [args.learning_rate]
|
| 897 |
+
ssm_weight_decays = parse_float_list(args.weight_decays) if args.weight_decays else [args.weight_decay]
|
| 898 |
+
hybrid_patterns = parse_str_list(args.hybrid_patterns) if args.hybrid_patterns else [args.hybrid_pattern]
|
| 899 |
+
ssm_gate_types = parse_str_list(args.ssm_gate_types) if args.ssm_gate_types else [args.ssm_gate_type]
|
| 900 |
+
|
| 901 |
+
output_dir = Path(args.output_dir)
|
| 902 |
+
rows: list[dict[str, Any]] = load_existing_rows(output_dir) if args.resume_completed else []
|
| 903 |
+
completed_case_ids = {
|
| 904 |
+
str(row.get("case_id"))
|
| 905 |
+
for row in rows
|
| 906 |
+
if row.get("mode") == "forward_backward" or not args.backward
|
| 907 |
+
}
|
| 908 |
+
for architecture in architectures:
|
| 909 |
+
hidden_options: list[int | None] = [None]
|
| 910 |
+
mixer_options: list[int | None] = [None]
|
| 911 |
+
learning_rate_options = [args.learning_rate]
|
| 912 |
+
weight_decay_options = [args.weight_decay]
|
| 913 |
+
hybrid_pattern_options: list[str | None] = [None]
|
| 914 |
+
gate_type_options: list[str | None] = [None]
|
| 915 |
+
lane_options: list[int | None] = [None]
|
| 916 |
+
lane_mode_options: list[str | None] = [None]
|
| 917 |
+
split_mix_options: list[str | None] = ["none"]
|
| 918 |
+
if architecture in {"taonet_ssm", "taonet_hybrid"}:
|
| 919 |
+
hidden_options = ssm_hidden_dims
|
| 920 |
+
mixer_options = ssm_mixer_dims
|
| 921 |
+
learning_rate_options = ssm_learning_rates
|
| 922 |
+
weight_decay_options = ssm_weight_decays
|
| 923 |
+
gate_type_options = ssm_gate_types
|
| 924 |
+
lane_options = ssm_num_lanes_options
|
| 925 |
+
lane_mode_options = ssm_lane_mode_options
|
| 926 |
+
split_mix_options = ssm_split_mix_options
|
| 927 |
+
if architecture == "taonet_hybrid":
|
| 928 |
+
hybrid_pattern_options = [args.hybrid_pattern] if args.hybrid_ssm_layers else hybrid_patterns
|
| 929 |
+
for ssm_hidden_dim in hidden_options:
|
| 930 |
+
for ssm_mixer_dim in mixer_options:
|
| 931 |
+
for ssm_num_lanes in lane_options:
|
| 932 |
+
for ssm_lane_mode in lane_mode_options:
|
| 933 |
+
for ssm_split_mix in split_mix_options:
|
| 934 |
+
if (
|
| 935 |
+
architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 936 |
+
and ssm_lane_mode == "split"
|
| 937 |
+
and (ssm_num_lanes is None or ssm_num_lanes <= 1)
|
| 938 |
+
):
|
| 939 |
+
continue
|
| 940 |
+
if (
|
| 941 |
+
architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 942 |
+
and ssm_lane_mode == "split"
|
| 943 |
+
and ssm_mixer_dim is not None
|
| 944 |
+
and ssm_num_lanes is not None
|
| 945 |
+
and ssm_mixer_dim % ssm_num_lanes != 0
|
| 946 |
+
):
|
| 947 |
+
print(
|
| 948 |
+
"Skipping split-lane case because ssm_mixer_dim is not divisible by "
|
| 949 |
+
f"ssm_num_lanes: mixer={ssm_mixer_dim} lanes={ssm_num_lanes}"
|
| 950 |
+
)
|
| 951 |
+
continue
|
| 952 |
+
if (
|
| 953 |
+
architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 954 |
+
and ssm_lane_mode != "split"
|
| 955 |
+
and ssm_split_mix != "none"
|
| 956 |
+
):
|
| 957 |
+
continue
|
| 958 |
+
if (
|
| 959 |
+
architecture in {"taonet_ssm", "taonet_hybrid"}
|
| 960 |
+
and ssm_split_mix == "hadamard"
|
| 961 |
+
and ssm_num_lanes != 2
|
| 962 |
+
):
|
| 963 |
+
continue
|
| 964 |
+
args.ssm_lane_mode = ssm_lane_mode or args.ssm_lane_mode
|
| 965 |
+
args.ssm_split_mix = ssm_split_mix or args.ssm_split_mix
|
| 966 |
+
for hybrid_pattern in hybrid_pattern_options:
|
| 967 |
+
for ssm_gate_type in gate_type_options:
|
| 968 |
+
for learning_rate in learning_rate_options:
|
| 969 |
+
for weight_decay in weight_decay_options:
|
| 970 |
+
for batch_size in parse_int_list(args.batch_sizes):
|
| 971 |
+
print(
|
| 972 |
+
"Benchmarking "
|
| 973 |
+
f"architecture={architecture} ssm_hidden_dim={ssm_hidden_dim} "
|
| 974 |
+
f"ssm_mixer_dim={ssm_mixer_dim} ssm_num_lanes={ssm_num_lanes} "
|
| 975 |
+
f"ssm_lane_mode={args.ssm_lane_mode} ssm_split_mix={args.ssm_split_mix} "
|
| 976 |
+
f"hybrid_pattern={hybrid_pattern} ssm_gate_type={ssm_gate_type} "
|
| 977 |
+
f"lr={learning_rate} wd={weight_decay} batch={batch_size} seq={args.seq_len}"
|
| 978 |
+
)
|
| 979 |
+
case_parts = make_case_parts(
|
| 980 |
+
args,
|
| 981 |
+
architecture=architecture,
|
| 982 |
+
ssm_hidden_dim=ssm_hidden_dim,
|
| 983 |
+
ssm_mixer_dim=ssm_mixer_dim,
|
| 984 |
+
ssm_num_lanes=ssm_num_lanes,
|
| 985 |
+
hybrid_pattern=hybrid_pattern,
|
| 986 |
+
ssm_gate_type=ssm_gate_type,
|
| 987 |
+
learning_rate=learning_rate,
|
| 988 |
+
weight_decay=weight_decay,
|
| 989 |
+
batch_size=batch_size,
|
| 990 |
+
)
|
| 991 |
+
case_id = stable_case_id(case_parts)
|
| 992 |
+
if args.resume_completed and case_id in completed_case_ids:
|
| 993 |
+
print(f"Skipping completed case_id={case_id}")
|
| 994 |
+
continue
|
| 995 |
+
case_rows = benchmark_case(
|
| 996 |
+
args=args,
|
| 997 |
+
architecture=architecture,
|
| 998 |
+
ssm_hidden_dim=ssm_hidden_dim,
|
| 999 |
+
ssm_mixer_dim=ssm_mixer_dim,
|
| 1000 |
+
ssm_num_lanes=ssm_num_lanes,
|
| 1001 |
+
hybrid_pattern=hybrid_pattern,
|
| 1002 |
+
ssm_gate_type=ssm_gate_type,
|
| 1003 |
+
learning_rate=learning_rate,
|
| 1004 |
+
weight_decay=weight_decay,
|
| 1005 |
+
vocab_size=vocab_size,
|
| 1006 |
+
train_stream=train_stream,
|
| 1007 |
+
eval_stream=eval_stream,
|
| 1008 |
+
batch_size=batch_size,
|
| 1009 |
+
dtype=dtype,
|
| 1010 |
+
device=device,
|
| 1011 |
+
)
|
| 1012 |
+
rows.extend(case_rows)
|
| 1013 |
+
completed_case_ids.add(case_id)
|
| 1014 |
+
if args.incremental_output:
|
| 1015 |
+
write_outputs(rows, output_dir, {})
|
| 1016 |
+
|
| 1017 |
+
metadata = {
|
| 1018 |
+
"python": platform.python_version(),
|
| 1019 |
+
"platform": platform.platform(),
|
| 1020 |
+
"torch": torch.__version__,
|
| 1021 |
+
"cuda_available": torch.cuda.is_available(),
|
| 1022 |
+
"cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
|
| 1023 |
+
"args": vars(args),
|
| 1024 |
+
"vocab_size": vocab_size,
|
| 1025 |
+
"num_tokens": int(tokens.numel()),
|
| 1026 |
+
"train_tokens": int(train_stream.numel()),
|
| 1027 |
+
"eval_tokens": int(eval_stream.numel()),
|
| 1028 |
+
}
|
| 1029 |
+
print_table(rows)
|
| 1030 |
+
write_outputs(rows, output_dir, metadata)
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
if __name__ == "__main__":
|
| 1034 |
+
main()
|
code/TaoTrain/scripts/benchmark_taonet_token_variants.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token-level benchmark for TaoNet attention vs TaoNet-SSM.
|
| 2 |
+
|
| 3 |
+
The goal is to compare the two LLM wrappers with the same outer dimensions:
|
| 4 |
+
original MLA attention TaoNet versus TaoNet with an SSM mixer.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
from contextlib import nullcontext
|
| 11 |
+
from contextlib import redirect_stdout
|
| 12 |
+
import csv
|
| 13 |
+
import io
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import platform
|
| 18 |
+
import subprocess
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 26 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 27 |
+
if str(SRC_ROOT) not in sys.path:
|
| 28 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 29 |
+
|
| 30 |
+
from taoTrain.config import ModelConfig
|
| 31 |
+
from taoTrain.models import get_model
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
DTYPES = {
|
| 35 |
+
"float32": torch.float32,
|
| 36 |
+
"fp32": torch.float32,
|
| 37 |
+
"float16": torch.float16,
|
| 38 |
+
"fp16": torch.float16,
|
| 39 |
+
"bfloat16": torch.bfloat16,
|
| 40 |
+
"bf16": torch.bfloat16,
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def parse_int_list(value: str) -> list[int]:
|
| 45 |
+
return [int(item.strip()) for item in value.split(",") if item.strip()]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def synchronize(device: torch.device) -> None:
|
| 49 |
+
if device.type == "cuda":
|
| 50 |
+
torch.cuda.synchronize(device)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def reset_memory(device: torch.device) -> None:
|
| 54 |
+
if device.type == "cuda":
|
| 55 |
+
torch.cuda.reset_peak_memory_stats(device)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def memory_stats(device: torch.device) -> dict[str, float | None]:
|
| 59 |
+
if device.type != "cuda":
|
| 60 |
+
return {
|
| 61 |
+
"peak_allocated_mb": None,
|
| 62 |
+
"peak_reserved_mb": None,
|
| 63 |
+
}
|
| 64 |
+
return {
|
| 65 |
+
"peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
|
| 66 |
+
"peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def nvidia_smi_snapshot() -> str | None:
|
| 71 |
+
try:
|
| 72 |
+
completed = subprocess.run(
|
| 73 |
+
[
|
| 74 |
+
"nvidia-smi",
|
| 75 |
+
"--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu",
|
| 76 |
+
"--format=csv,noheader,nounits",
|
| 77 |
+
],
|
| 78 |
+
check=False,
|
| 79 |
+
capture_output=True,
|
| 80 |
+
text=True,
|
| 81 |
+
timeout=5,
|
| 82 |
+
)
|
| 83 |
+
except (OSError, subprocess.TimeoutExpired):
|
| 84 |
+
return None
|
| 85 |
+
if completed.returncode != 0:
|
| 86 |
+
return None
|
| 87 |
+
return completed.stdout.strip()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def make_token_batch(
|
| 91 |
+
*,
|
| 92 |
+
batch_size: int,
|
| 93 |
+
seq_len: int,
|
| 94 |
+
vocab_size: int,
|
| 95 |
+
device: torch.device,
|
| 96 |
+
task: str = "random",
|
| 97 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 98 |
+
if task == "random":
|
| 99 |
+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
| 100 |
+
labels = torch.empty_like(input_ids)
|
| 101 |
+
labels[:, :-1] = input_ids[:, 1:]
|
| 102 |
+
labels[:, -1] = torch.randint(0, vocab_size, (batch_size,), device=device)
|
| 103 |
+
elif task == "increment":
|
| 104 |
+
starts = torch.randint(0, vocab_size, (batch_size, 1), device=device)
|
| 105 |
+
offsets = torch.arange(seq_len, device=device).view(1, seq_len)
|
| 106 |
+
input_ids = (starts + offsets) % vocab_size
|
| 107 |
+
labels = (input_ids + 1) % vocab_size
|
| 108 |
+
elif task == "previous":
|
| 109 |
+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
| 110 |
+
labels = torch.empty_like(input_ids)
|
| 111 |
+
labels[:, 0] = -100
|
| 112 |
+
labels[:, 1:] = input_ids[:, :-1]
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError(f"Unsupported token task '{task}'.")
|
| 115 |
+
attention_mask = torch.ones_like(input_ids)
|
| 116 |
+
return input_ids, labels, attention_mask
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
|
| 120 |
+
valid = labels != -100
|
| 121 |
+
if not torch.any(valid):
|
| 122 |
+
return float("nan")
|
| 123 |
+
predictions = torch.argmax(logits, dim=-1)
|
| 124 |
+
correct = (predictions == labels) & valid
|
| 125 |
+
return float(correct.sum().detach().cpu() / valid.sum().detach().cpu())
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
|
| 129 |
+
d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75)
|
| 130 |
+
d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads
|
| 131 |
+
hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4
|
| 132 |
+
return ModelConfig(
|
| 133 |
+
architecture_type=architecture,
|
| 134 |
+
vocab_size=args.vocab_size,
|
| 135 |
+
hidden_dim=args.hidden_dim,
|
| 136 |
+
num_layers=args.num_layers,
|
| 137 |
+
num_heads=args.num_heads,
|
| 138 |
+
max_seq_length=max(parse_int_list(args.seq_lens)),
|
| 139 |
+
d_latent_kv=d_latent_kv,
|
| 140 |
+
d_rope=d_rope,
|
| 141 |
+
hidden_dim_ff=hidden_dim_ff,
|
| 142 |
+
dropout=args.dropout,
|
| 143 |
+
gqa_groups=args.gqa_groups,
|
| 144 |
+
rope_scale=args.rope_scale,
|
| 145 |
+
yarn_alpha=args.yarn_alpha,
|
| 146 |
+
init_std=args.init_std,
|
| 147 |
+
ssm_core=args.ssm_core,
|
| 148 |
+
ssm_hidden_dim=args.ssm_hidden_dim or d_latent_kv,
|
| 149 |
+
ssm_mixer_dim=args.ssm_mixer_dim,
|
| 150 |
+
ssm_rank=args.ssm_rank,
|
| 151 |
+
ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
|
| 152 |
+
ssm_kernel_mode=args.ssm_kernel_mode,
|
| 153 |
+
ssm_kernel_threshold=args.ssm_kernel_threshold,
|
| 154 |
+
ssm_dt_min=args.ssm_dt_min,
|
| 155 |
+
ssm_dt_max=args.ssm_dt_max,
|
| 156 |
+
ssm_dt_init=args.ssm_dt_init,
|
| 157 |
+
ssm_use_padding_mask=args.ssm_use_padding_mask,
|
| 158 |
+
ssm_activation=args.ssm_activation,
|
| 159 |
+
ssm_gate=args.ssm_gate,
|
| 160 |
+
ssm_input_gate=args.ssm_input_gate,
|
| 161 |
+
ssm_layer_scale_init=args.ssm_layer_scale_init,
|
| 162 |
+
ssm_local_shift=args.ssm_local_shift,
|
| 163 |
+
ssm_local_shift_init=args.ssm_local_shift_init,
|
| 164 |
+
ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def count_params(model: torch.nn.Module) -> tuple[int, int]:
|
| 169 |
+
total = sum(param.numel() for param in model.parameters())
|
| 170 |
+
trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
|
| 171 |
+
return total, trainable
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]:
|
| 175 |
+
last_loss = float("nan")
|
| 176 |
+
for _ in range(warmup):
|
| 177 |
+
last_loss = fn()
|
| 178 |
+
synchronize(device)
|
| 179 |
+
|
| 180 |
+
latencies = []
|
| 181 |
+
for _ in range(repeats):
|
| 182 |
+
reset_memory(device)
|
| 183 |
+
synchronize(device)
|
| 184 |
+
start = time.perf_counter()
|
| 185 |
+
last_loss = fn()
|
| 186 |
+
synchronize(device)
|
| 187 |
+
latencies.append(time.perf_counter() - start)
|
| 188 |
+
return sum(latencies) / len(latencies), min(latencies), last_loss
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def evaluate_model(
|
| 192 |
+
model: torch.nn.Module,
|
| 193 |
+
*,
|
| 194 |
+
args: argparse.Namespace,
|
| 195 |
+
batch_size: int,
|
| 196 |
+
seq_len: int,
|
| 197 |
+
device: torch.device,
|
| 198 |
+
autocast_context,
|
| 199 |
+
) -> tuple[float, float]:
|
| 200 |
+
model.eval()
|
| 201 |
+
losses = []
|
| 202 |
+
accuracies = []
|
| 203 |
+
with torch.no_grad():
|
| 204 |
+
for _ in range(args.eval_batches):
|
| 205 |
+
input_ids, labels, attention_mask = make_token_batch(
|
| 206 |
+
batch_size=batch_size,
|
| 207 |
+
seq_len=seq_len,
|
| 208 |
+
vocab_size=args.vocab_size,
|
| 209 |
+
device=device,
|
| 210 |
+
task=args.token_task,
|
| 211 |
+
)
|
| 212 |
+
with autocast_context():
|
| 213 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 214 |
+
losses.append(float(outputs["loss"].detach().cpu()))
|
| 215 |
+
accuracies.append(token_accuracy(outputs["logits"], labels))
|
| 216 |
+
model.train()
|
| 217 |
+
return sum(losses) / len(losses), sum(accuracies) / len(accuracies)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def train_model(
|
| 221 |
+
model: torch.nn.Module,
|
| 222 |
+
*,
|
| 223 |
+
args: argparse.Namespace,
|
| 224 |
+
batch_size: int,
|
| 225 |
+
seq_len: int,
|
| 226 |
+
device: torch.device,
|
| 227 |
+
autocast_context,
|
| 228 |
+
) -> tuple[float | None, float | None]:
|
| 229 |
+
if args.train_steps <= 0:
|
| 230 |
+
return None, None
|
| 231 |
+
|
| 232 |
+
model.train()
|
| 233 |
+
optimizer = torch.optim.AdamW(
|
| 234 |
+
model.parameters(),
|
| 235 |
+
lr=args.learning_rate,
|
| 236 |
+
weight_decay=args.weight_decay,
|
| 237 |
+
)
|
| 238 |
+
last_loss = float("nan")
|
| 239 |
+
start = time.perf_counter()
|
| 240 |
+
for _ in range(args.train_steps):
|
| 241 |
+
input_ids, labels, attention_mask = make_token_batch(
|
| 242 |
+
batch_size=batch_size,
|
| 243 |
+
seq_len=seq_len,
|
| 244 |
+
vocab_size=args.vocab_size,
|
| 245 |
+
device=device,
|
| 246 |
+
task=args.token_task,
|
| 247 |
+
)
|
| 248 |
+
optimizer.zero_grad(set_to_none=True)
|
| 249 |
+
with autocast_context():
|
| 250 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 251 |
+
loss = outputs["loss"]
|
| 252 |
+
loss.backward()
|
| 253 |
+
optimizer.step()
|
| 254 |
+
last_loss = float(loss.detach().cpu())
|
| 255 |
+
synchronize(device)
|
| 256 |
+
return last_loss, time.perf_counter() - start
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def benchmark_case(
|
| 260 |
+
*,
|
| 261 |
+
args: argparse.Namespace,
|
| 262 |
+
architecture: str,
|
| 263 |
+
batch_size: int,
|
| 264 |
+
seq_len: int,
|
| 265 |
+
dtype: torch.dtype,
|
| 266 |
+
device: torch.device,
|
| 267 |
+
) -> list[dict[str, Any]]:
|
| 268 |
+
config = build_config(args, architecture)
|
| 269 |
+
with redirect_stdout(io.StringIO()):
|
| 270 |
+
model = get_model(config, device=device)
|
| 271 |
+
model.train()
|
| 272 |
+
total_params, trainable_params = count_params(model)
|
| 273 |
+
tokens = batch_size * seq_len
|
| 274 |
+
input_ids, labels, attention_mask = make_token_batch(
|
| 275 |
+
batch_size=batch_size,
|
| 276 |
+
seq_len=seq_len,
|
| 277 |
+
vocab_size=args.vocab_size,
|
| 278 |
+
device=device,
|
| 279 |
+
task=args.token_task,
|
| 280 |
+
)
|
| 281 |
+
autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 282 |
+
|
| 283 |
+
def autocast_context():
|
| 284 |
+
if not autocast_enabled:
|
| 285 |
+
return nullcontext()
|
| 286 |
+
return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)
|
| 287 |
+
|
| 288 |
+
train_final_loss, train_seconds = train_model(
|
| 289 |
+
model,
|
| 290 |
+
args=args,
|
| 291 |
+
batch_size=batch_size,
|
| 292 |
+
seq_len=seq_len,
|
| 293 |
+
device=device,
|
| 294 |
+
autocast_context=autocast_context,
|
| 295 |
+
)
|
| 296 |
+
eval_loss, eval_accuracy = evaluate_model(
|
| 297 |
+
model,
|
| 298 |
+
args=args,
|
| 299 |
+
batch_size=batch_size,
|
| 300 |
+
seq_len=seq_len,
|
| 301 |
+
device=device,
|
| 302 |
+
autocast_context=autocast_context,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
rows: list[dict[str, Any]] = []
|
| 306 |
+
|
| 307 |
+
def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None:
|
| 308 |
+
rows.append(
|
| 309 |
+
{
|
| 310 |
+
"architecture": architecture,
|
| 311 |
+
"ssm_core": args.ssm_core if architecture == "taonet_ssm" else None,
|
| 312 |
+
"token_task": args.token_task,
|
| 313 |
+
"train_steps": args.train_steps,
|
| 314 |
+
"mode": mode,
|
| 315 |
+
"batch_size": batch_size,
|
| 316 |
+
"seq_len": seq_len,
|
| 317 |
+
"tokens": tokens,
|
| 318 |
+
"vocab_size": args.vocab_size,
|
| 319 |
+
"hidden_dim": args.hidden_dim,
|
| 320 |
+
"num_layers": args.num_layers,
|
| 321 |
+
"num_heads": args.num_heads,
|
| 322 |
+
"d_latent_kv": config.d_latent_kv,
|
| 323 |
+
"ssm_hidden_dim": config.ssm_hidden_dim if architecture == "taonet_ssm" else None,
|
| 324 |
+
"ssm_mixer_dim": config.ssm_mixer_dim if architecture == "taonet_ssm" else None,
|
| 325 |
+
"ssm_rank": config.ssm_rank if architecture == "taonet_ssm" else None,
|
| 326 |
+
"ssm_local_shift": config.ssm_local_shift if architecture == "taonet_ssm" else None,
|
| 327 |
+
"ssm_local_shift_init": config.ssm_local_shift_init if architecture == "taonet_ssm" else None,
|
| 328 |
+
"ssm_local_shift_per_channel": config.ssm_local_shift_per_channel if architecture == "taonet_ssm" else None,
|
| 329 |
+
"dtype": str(dtype).replace("torch.", ""),
|
| 330 |
+
"device": str(device),
|
| 331 |
+
"total_params": total_params,
|
| 332 |
+
"trainable_params": trainable_params,
|
| 333 |
+
"mean_ms": mean_s * 1000.0,
|
| 334 |
+
"min_ms": min_s * 1000.0,
|
| 335 |
+
"tokens_per_s_mean": tokens / max(mean_s, 1e-12),
|
| 336 |
+
"tokens_per_s_best": tokens / max(min_s, 1e-12),
|
| 337 |
+
"loss": loss,
|
| 338 |
+
"eval_loss": eval_loss,
|
| 339 |
+
"eval_accuracy": eval_accuracy,
|
| 340 |
+
"train_final_loss": train_final_loss,
|
| 341 |
+
"train_seconds": train_seconds,
|
| 342 |
+
**memory_stats(device),
|
| 343 |
+
}
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
def forward_only() -> float:
|
| 347 |
+
with torch.no_grad():
|
| 348 |
+
with autocast_context():
|
| 349 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 350 |
+
loss = outputs["loss"]
|
| 351 |
+
return float(loss.detach().cpu())
|
| 352 |
+
|
| 353 |
+
mean_s, min_s, loss = time_repeats(
|
| 354 |
+
forward_only,
|
| 355 |
+
device=device,
|
| 356 |
+
warmup=args.warmup,
|
| 357 |
+
repeats=args.repeats,
|
| 358 |
+
)
|
| 359 |
+
add_row("forward", mean_s, min_s, loss)
|
| 360 |
+
|
| 361 |
+
if args.backward:
|
| 362 |
+
def forward_backward() -> float:
|
| 363 |
+
model.zero_grad(set_to_none=True)
|
| 364 |
+
with autocast_context():
|
| 365 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 366 |
+
loss = outputs["loss"]
|
| 367 |
+
loss.backward()
|
| 368 |
+
return float(loss.detach().cpu())
|
| 369 |
+
|
| 370 |
+
mean_s, min_s, loss = time_repeats(
|
| 371 |
+
forward_backward,
|
| 372 |
+
device=device,
|
| 373 |
+
warmup=args.warmup,
|
| 374 |
+
repeats=args.repeats,
|
| 375 |
+
)
|
| 376 |
+
add_row("forward_backward", mean_s, min_s, loss)
|
| 377 |
+
|
| 378 |
+
return rows
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def print_table(rows: list[dict[str, Any]]) -> None:
|
| 382 |
+
columns = [
|
| 383 |
+
"architecture",
|
| 384 |
+
"ssm_core",
|
| 385 |
+
"token_task",
|
| 386 |
+
"mode",
|
| 387 |
+
"batch_size",
|
| 388 |
+
"seq_len",
|
| 389 |
+
"mean_ms",
|
| 390 |
+
"tokens_per_s_mean",
|
| 391 |
+
"peak_allocated_mb",
|
| 392 |
+
"loss",
|
| 393 |
+
"eval_loss",
|
| 394 |
+
"eval_accuracy",
|
| 395 |
+
]
|
| 396 |
+
print("\t".join(columns))
|
| 397 |
+
for row in rows:
|
| 398 |
+
values = []
|
| 399 |
+
for column in columns:
|
| 400 |
+
value = row[column]
|
| 401 |
+
if isinstance(value, float):
|
| 402 |
+
values.append(f"{value:.3f}")
|
| 403 |
+
else:
|
| 404 |
+
values.append(str(value))
|
| 405 |
+
print("\t".join(values))
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
|
| 409 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 410 |
+
json_path = output_dir / "taonet_token_benchmark.json"
|
| 411 |
+
csv_path = output_dir / "taonet_token_benchmark.csv"
|
| 412 |
+
json_path.write_text(json.dumps({"metadata": metadata, "results": rows}, indent=2), encoding="utf-8")
|
| 413 |
+
|
| 414 |
+
fieldnames = list(rows[0].keys()) if rows else []
|
| 415 |
+
with csv_path.open("w", newline="", encoding="utf-8") as handle:
|
| 416 |
+
writer = csv.DictWriter(handle, fieldnames=fieldnames)
|
| 417 |
+
writer.writeheader()
|
| 418 |
+
writer.writerows(rows)
|
| 419 |
+
|
| 420 |
+
print(f"Wrote {json_path}")
|
| 421 |
+
print(f"Wrote {csv_path}")
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def main() -> None:
|
| 425 |
+
parser = argparse.ArgumentParser(description="Benchmark TaoNet attention vs TaoNet-SSM on token batches.")
|
| 426 |
+
parser.add_argument("--architectures", default="taonet,taonet_ssm")
|
| 427 |
+
parser.add_argument("--batch-sizes", default="1,4")
|
| 428 |
+
parser.add_argument("--seq-lens", default="128,512")
|
| 429 |
+
parser.add_argument("--vocab-size", type=int, default=8192)
|
| 430 |
+
parser.add_argument("--hidden-dim", type=int, default=256)
|
| 431 |
+
parser.add_argument("--num-layers", type=int, default=4)
|
| 432 |
+
parser.add_argument("--num-heads", type=int, default=4)
|
| 433 |
+
parser.add_argument("--d-latent-kv", type=int, default=None)
|
| 434 |
+
parser.add_argument("--d-rope", type=int, default=None)
|
| 435 |
+
parser.add_argument("--hidden-dim-ff", type=int, default=None)
|
| 436 |
+
parser.add_argument("--dropout", type=float, default=0.0)
|
| 437 |
+
parser.add_argument("--gqa-groups", type=int, default=1)
|
| 438 |
+
parser.add_argument("--rope-scale", type=float, default=40.0)
|
| 439 |
+
parser.add_argument("--yarn-alpha", type=float, default=1.0)
|
| 440 |
+
parser.add_argument("--init-std", type=float, default=0.02)
|
| 441 |
+
parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
|
| 442 |
+
parser.add_argument("--ssm-hidden-dim", type=int, default=None)
|
| 443 |
+
parser.add_argument("--ssm-mixer-dim", type=int, default=None)
|
| 444 |
+
parser.add_argument("--ssm-rank", type=int, default=1)
|
| 445 |
+
parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
|
| 446 |
+
parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
|
| 447 |
+
parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
|
| 448 |
+
parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
|
| 449 |
+
parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
|
| 450 |
+
parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
|
| 451 |
+
parser.add_argument("--ssm-use-padding-mask", action="store_true")
|
| 452 |
+
parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
|
| 453 |
+
parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
|
| 454 |
+
parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
|
| 455 |
+
parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
|
| 456 |
+
parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=False)
|
| 457 |
+
parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
|
| 458 |
+
parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=False)
|
| 459 |
+
parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
|
| 460 |
+
parser.add_argument("--device", default="auto")
|
| 461 |
+
parser.add_argument("--warmup", type=int, default=2)
|
| 462 |
+
parser.add_argument("--repeats", type=int, default=5)
|
| 463 |
+
parser.add_argument("--backward", action="store_true")
|
| 464 |
+
parser.add_argument("--token-task", choices=["random", "increment", "previous"], default="random")
|
| 465 |
+
parser.add_argument("--train-steps", type=int, default=0)
|
| 466 |
+
parser.add_argument("--learning-rate", type=float, default=3e-4)
|
| 467 |
+
parser.add_argument("--weight-decay", type=float, default=0.01)
|
| 468 |
+
parser.add_argument("--eval-batches", type=int, default=1)
|
| 469 |
+
parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/token-bench"))
|
| 470 |
+
args = parser.parse_args()
|
| 471 |
+
|
| 472 |
+
if args.device == "auto":
|
| 473 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 474 |
+
else:
|
| 475 |
+
device = torch.device(args.device)
|
| 476 |
+
dtype = DTYPES[args.dtype]
|
| 477 |
+
if device.type != "cuda" and dtype == torch.float16:
|
| 478 |
+
raise ValueError("float16 benchmark requires CUDA.")
|
| 479 |
+
if device.type == "cuda":
|
| 480 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 481 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 482 |
+
|
| 483 |
+
architectures = [item.strip() for item in args.architectures.split(",") if item.strip()]
|
| 484 |
+
rows: list[dict[str, Any]] = []
|
| 485 |
+
metadata = {
|
| 486 |
+
"python": platform.python_version(),
|
| 487 |
+
"platform": platform.platform(),
|
| 488 |
+
"torch": torch.__version__,
|
| 489 |
+
"cuda_available": torch.cuda.is_available(),
|
| 490 |
+
"cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
|
| 491 |
+
"nvidia_smi_before": nvidia_smi_snapshot(),
|
| 492 |
+
"args": vars(args),
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
for architecture in architectures:
|
| 496 |
+
for batch_size in parse_int_list(args.batch_sizes):
|
| 497 |
+
for seq_len in parse_int_list(args.seq_lens):
|
| 498 |
+
print(f"Benchmarking architecture={architecture} batch={batch_size} seq={seq_len}")
|
| 499 |
+
rows.extend(
|
| 500 |
+
benchmark_case(
|
| 501 |
+
args=args,
|
| 502 |
+
architecture=architecture,
|
| 503 |
+
batch_size=batch_size,
|
| 504 |
+
seq_len=seq_len,
|
| 505 |
+
dtype=dtype,
|
| 506 |
+
device=device,
|
| 507 |
+
)
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
metadata["nvidia_smi_after"] = nvidia_smi_snapshot()
|
| 511 |
+
print_table(rows)
|
| 512 |
+
write_outputs(rows, Path(args.output_dir), metadata)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
if __name__ == "__main__":
|
| 516 |
+
main()
|
code/TaoTrain/scripts/diagnostics/activation_probe.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Probe residual activation scale for a saved TaoTrain checkpoint."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 15 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 16 |
+
if str(SRC_ROOT) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 18 |
+
|
| 19 |
+
from taoTrain.checkpointing.checkpoint import CheckpointManager
|
| 20 |
+
from taoTrain.config import ModelConfig
|
| 21 |
+
from taoTrain.models import get_model
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_sentencepiece(path: Path):
|
| 25 |
+
import sentencepiece as spm
|
| 26 |
+
|
| 27 |
+
processor = spm.SentencePieceProcessor()
|
| 28 |
+
processor.load(str(path))
|
| 29 |
+
return processor
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_tokens(args: argparse.Namespace) -> tuple[torch.Tensor, int]:
|
| 33 |
+
tokenizer = load_sentencepiece(Path(args.tokenizer_path))
|
| 34 |
+
tokens: list[int] = []
|
| 35 |
+
with Path(args.data_path).open("r", encoding="utf-8", errors="replace") as handle:
|
| 36 |
+
for line in handle:
|
| 37 |
+
if len(tokens) >= args.max_tokens:
|
| 38 |
+
break
|
| 39 |
+
line = line.strip()
|
| 40 |
+
if not line:
|
| 41 |
+
continue
|
| 42 |
+
try:
|
| 43 |
+
record = json.loads(line)
|
| 44 |
+
except json.JSONDecodeError:
|
| 45 |
+
continue
|
| 46 |
+
text = record.get(args.text_field)
|
| 47 |
+
if not isinstance(text, str) or not text:
|
| 48 |
+
continue
|
| 49 |
+
ids = list(tokenizer.encode(text, out_type=int))
|
| 50 |
+
eos_id = tokenizer.eos_id()
|
| 51 |
+
if eos_id >= 0:
|
| 52 |
+
ids.append(eos_id)
|
| 53 |
+
tokens.extend(ids)
|
| 54 |
+
if len(tokens) < args.seq_len + 2:
|
| 55 |
+
raise ValueError(f"Need at least {args.seq_len + 2} tokens, got {len(tokens)}")
|
| 56 |
+
return torch.tensor(tokens[: args.max_tokens], dtype=torch.long), int(tokenizer.vocab_size())
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def sample_batch(tokens: torch.Tensor, *, batch_size: int, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
| 60 |
+
max_start = tokens.numel() - seq_len - 1
|
| 61 |
+
starts = torch.linspace(0, max_start - 1, steps=batch_size).long()
|
| 62 |
+
rows = [tokens[int(start) : int(start) + seq_len + 1] for start in starts]
|
| 63 |
+
batch = torch.stack(rows, dim=0).to(device=device)
|
| 64 |
+
return batch[:, :-1].contiguous(), batch[:, 1:].contiguous()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def tensor_stats(value: torch.Tensor) -> dict[str, float | int]:
|
| 68 |
+
data = value.detach().float()
|
| 69 |
+
finite = torch.isfinite(data)
|
| 70 |
+
finite_count = int(finite.sum().cpu())
|
| 71 |
+
numel = data.numel()
|
| 72 |
+
if finite_count:
|
| 73 |
+
finite_data = data[finite]
|
| 74 |
+
rms = float(torch.sqrt(torch.mean(finite_data * finite_data)).cpu())
|
| 75 |
+
max_abs = float(finite_data.abs().max().cpu())
|
| 76 |
+
else:
|
| 77 |
+
rms = float("inf")
|
| 78 |
+
max_abs = float("inf")
|
| 79 |
+
return {
|
| 80 |
+
"numel": numel,
|
| 81 |
+
"finite": finite_count,
|
| 82 |
+
"rms": rms,
|
| 83 |
+
"max_abs": max_abs,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def main() -> None:
|
| 88 |
+
parser = argparse.ArgumentParser()
|
| 89 |
+
parser.add_argument("--checkpoint", required=True)
|
| 90 |
+
parser.add_argument("--tokenizer-path", required=True)
|
| 91 |
+
parser.add_argument("--data-path", required=True)
|
| 92 |
+
parser.add_argument("--text-field", default="text")
|
| 93 |
+
parser.add_argument("--output", required=True)
|
| 94 |
+
parser.add_argument("--batch-size", type=int, default=2)
|
| 95 |
+
parser.add_argument("--seq-len", type=int, default=512)
|
| 96 |
+
parser.add_argument("--max-tokens", type=int, default=200_000)
|
| 97 |
+
parser.add_argument("--device", default="cuda")
|
| 98 |
+
parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
|
| 99 |
+
args = parser.parse_args()
|
| 100 |
+
|
| 101 |
+
device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
|
| 102 |
+
dtype = {
|
| 103 |
+
"float32": torch.float32,
|
| 104 |
+
"bfloat16": torch.bfloat16,
|
| 105 |
+
"float16": torch.float16,
|
| 106 |
+
}[args.dtype]
|
| 107 |
+
|
| 108 |
+
tokens, _ = load_tokens(args)
|
| 109 |
+
input_ids, labels = sample_batch(tokens, batch_size=args.batch_size, seq_len=args.seq_len, device=device)
|
| 110 |
+
attention_mask = torch.ones_like(input_ids)
|
| 111 |
+
|
| 112 |
+
checkpoint_path = Path(args.checkpoint)
|
| 113 |
+
checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
|
| 114 |
+
config_dict = checkpoint.get("config", {})
|
| 115 |
+
model_config = ModelConfig(**config_dict.get("model", {}))
|
| 116 |
+
model = get_model(model_config, device=device)
|
| 117 |
+
model.load_state_dict(checkpoint["model_state"], strict=False)
|
| 118 |
+
model.eval()
|
| 119 |
+
|
| 120 |
+
layer_stats: dict[str, dict[str, float | int]] = {}
|
| 121 |
+
handles = []
|
| 122 |
+
layer_pattern = re.compile(r"^(?:model\.)?(?:layers|blocks)\.\d+$")
|
| 123 |
+
|
| 124 |
+
def make_hook(name: str):
|
| 125 |
+
def hook(_module, _inputs, output):
|
| 126 |
+
value = output[0] if isinstance(output, tuple) else output
|
| 127 |
+
if torch.is_tensor(value):
|
| 128 |
+
layer_stats[name] = tensor_stats(value)
|
| 129 |
+
|
| 130 |
+
return hook
|
| 131 |
+
|
| 132 |
+
for name, module in model.named_modules():
|
| 133 |
+
if layer_pattern.match(name):
|
| 134 |
+
handles.append(module.register_forward_hook(make_hook(name)))
|
| 135 |
+
|
| 136 |
+
device_type = "cuda" if device.type == "cuda" else "cpu"
|
| 137 |
+
autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 138 |
+
with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
|
| 139 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 140 |
+
|
| 141 |
+
for handle in handles:
|
| 142 |
+
handle.remove()
|
| 143 |
+
|
| 144 |
+
result: dict[str, Any] = {
|
| 145 |
+
"checkpoint": str(checkpoint_path),
|
| 146 |
+
"loss": float(outputs["loss"].detach().cpu()),
|
| 147 |
+
"batch_size": args.batch_size,
|
| 148 |
+
"seq_len": args.seq_len,
|
| 149 |
+
"device": str(device),
|
| 150 |
+
"dtype": str(dtype),
|
| 151 |
+
"layers": layer_stats,
|
| 152 |
+
}
|
| 153 |
+
output_path = Path(args.output)
|
| 154 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 155 |
+
output_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
| 156 |
+
print(json.dumps(result, indent=2))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
main()
|
code/TaoTrain/scripts/diagnostics/generate_checkpoint_samples.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate a few text samples from a saved checkpoint."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 13 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 14 |
+
if str(SRC_ROOT) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 16 |
+
|
| 17 |
+
from taoTrain.checkpointing.checkpoint import CheckpointManager
|
| 18 |
+
from taoTrain.config import ModelConfig
|
| 19 |
+
from taoTrain.inference.inferencer import Inferencer
|
| 20 |
+
from taoTrain.models import get_model
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def clear_kernel_caches(model) -> None:
|
| 24 |
+
for module in model.modules():
|
| 25 |
+
clear = getattr(module, "clear_kernel_cache", None)
|
| 26 |
+
if callable(clear):
|
| 27 |
+
clear()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def generate_once(
|
| 31 |
+
model,
|
| 32 |
+
tokenizer,
|
| 33 |
+
prompt: str,
|
| 34 |
+
*,
|
| 35 |
+
device: torch.device,
|
| 36 |
+
max_new_tokens: int,
|
| 37 |
+
temperature: float,
|
| 38 |
+
top_p: float,
|
| 39 |
+
dtype: torch.dtype,
|
| 40 |
+
) -> str:
|
| 41 |
+
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
| 42 |
+
generated = []
|
| 43 |
+
eos_token_id = getattr(tokenizer, "eos_token_id", None)
|
| 44 |
+
model.eval()
|
| 45 |
+
device_type = "cuda" if device.type == "cuda" else "cpu"
|
| 46 |
+
autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 47 |
+
with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
|
| 48 |
+
for _ in range(max_new_tokens):
|
| 49 |
+
clear_kernel_caches(model)
|
| 50 |
+
outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None)
|
| 51 |
+
logits = outputs["logits"][:, -1, :] / max(temperature, 1e-6)
|
| 52 |
+
if top_p < 1.0:
|
| 53 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 54 |
+
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
| 55 |
+
cumulative = torch.cumsum(sorted_probs, dim=-1)
|
| 56 |
+
remove = cumulative > top_p
|
| 57 |
+
remove[..., 1:] = remove[..., :-1].clone()
|
| 58 |
+
remove[..., 0] = False
|
| 59 |
+
indices_to_remove = sorted_indices[remove]
|
| 60 |
+
logits[0, indices_to_remove] = float("-inf")
|
| 61 |
+
probs = torch.softmax(logits, dim=-1)
|
| 62 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 63 |
+
token_id = int(next_token.item())
|
| 64 |
+
if eos_token_id is not None and token_id == eos_token_id:
|
| 65 |
+
break
|
| 66 |
+
generated.append(token_id)
|
| 67 |
+
input_ids = torch.cat([input_ids, next_token], dim=-1)
|
| 68 |
+
clear_kernel_caches(model)
|
| 69 |
+
return tokenizer.decode(generated, skip_special_tokens=True)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main() -> None:
|
| 73 |
+
parser = argparse.ArgumentParser()
|
| 74 |
+
parser.add_argument("--checkpoint", required=True)
|
| 75 |
+
parser.add_argument("--tokenizer-path", required=True)
|
| 76 |
+
parser.add_argument("--output", required=True)
|
| 77 |
+
parser.add_argument("--prompt", action="append", default=[])
|
| 78 |
+
parser.add_argument("--max-new-tokens", type=int, default=80)
|
| 79 |
+
parser.add_argument("--temperature", type=float, default=0.8)
|
| 80 |
+
parser.add_argument("--top-p", type=float, default=0.9)
|
| 81 |
+
parser.add_argument("--device", default="cuda")
|
| 82 |
+
parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
|
| 83 |
+
args = parser.parse_args()
|
| 84 |
+
|
| 85 |
+
prompts = args.prompt or [
|
| 86 |
+
"The purpose of artificial intelligence is",
|
| 87 |
+
"In a small village,",
|
| 88 |
+
"<user>Hello, who are you?<assistant>",
|
| 89 |
+
]
|
| 90 |
+
device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
|
| 91 |
+
dtype = {
|
| 92 |
+
"float32": torch.float32,
|
| 93 |
+
"bfloat16": torch.bfloat16,
|
| 94 |
+
"float16": torch.float16,
|
| 95 |
+
}[args.dtype]
|
| 96 |
+
tokenizer = Inferencer._load_tokenizer(args.tokenizer_path)
|
| 97 |
+
checkpoint_path = Path(args.checkpoint)
|
| 98 |
+
checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
|
| 99 |
+
model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {}))
|
| 100 |
+
model = get_model(model_config, device=device)
|
| 101 |
+
model.load_state_dict(checkpoint["model_state"], strict=False)
|
| 102 |
+
|
| 103 |
+
samples = []
|
| 104 |
+
for prompt in prompts:
|
| 105 |
+
text = generate_once(
|
| 106 |
+
model,
|
| 107 |
+
tokenizer,
|
| 108 |
+
prompt,
|
| 109 |
+
device=device,
|
| 110 |
+
max_new_tokens=args.max_new_tokens,
|
| 111 |
+
temperature=args.temperature,
|
| 112 |
+
top_p=args.top_p,
|
| 113 |
+
dtype=dtype,
|
| 114 |
+
)
|
| 115 |
+
samples.append({"prompt": prompt, "completion": text})
|
| 116 |
+
|
| 117 |
+
result = {
|
| 118 |
+
"checkpoint": args.checkpoint,
|
| 119 |
+
"tokenizer_path": args.tokenizer_path,
|
| 120 |
+
"device": str(device),
|
| 121 |
+
"dtype": str(dtype),
|
| 122 |
+
"max_new_tokens": args.max_new_tokens,
|
| 123 |
+
"temperature": args.temperature,
|
| 124 |
+
"top_p": args.top_p,
|
| 125 |
+
"samples": samples,
|
| 126 |
+
}
|
| 127 |
+
output = Path(args.output)
|
| 128 |
+
output.parent.mkdir(parents=True, exist_ok=True)
|
| 129 |
+
output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
|
| 130 |
+
print(json.dumps(result, indent=2, ensure_ascii=False))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
|
| 134 |
+
main()
|
code/TaoTrain/scripts/diagnostics/sft_sanity_check.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Small SFT diagnostics for checkpoint quality and trainability.
|
| 2 |
+
|
| 3 |
+
This script intentionally bypasses the full trainer so it can answer one narrow
|
| 4 |
+
question quickly: can the checkpoint reduce response-only SFT loss on a tiny,
|
| 5 |
+
fixed batch?
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from taoTrain.checkpointing.checkpoint import CheckpointManager
|
| 19 |
+
from taoTrain.config import TrainingModeEnum, load_config
|
| 20 |
+
from taoTrain.core import create_model
|
| 21 |
+
from taoTrain.data.sft_utils import build_sft_sequence_tokens, parse_sft_record
|
| 22 |
+
try:
|
| 23 |
+
from taoTrain.data.sft_utils import build_response_only_next_token_labels
|
| 24 |
+
except ImportError:
|
| 25 |
+
def build_response_only_next_token_labels(input_ids: list[int], mask: list[int]) -> list[int]:
|
| 26 |
+
labels = [token_id if mask_value else -100 for token_id, mask_value in zip(input_ids, mask)]
|
| 27 |
+
return labels[1:] + [-100]
|
| 28 |
+
from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper
|
| 29 |
+
from taoTrain.utils import set_seed
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_tokenizer(tokenizer_path: str):
|
| 33 |
+
path = Path(tokenizer_path)
|
| 34 |
+
if path.suffix == ".model":
|
| 35 |
+
import sentencepiece as spm
|
| 36 |
+
|
| 37 |
+
sp = spm.SentencePieceProcessor()
|
| 38 |
+
sp.Load(str(path))
|
| 39 |
+
return SentencePieceTokenizerWrapper(sp)
|
| 40 |
+
|
| 41 |
+
from transformers import AutoTokenizer
|
| 42 |
+
|
| 43 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 44 |
+
if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None):
|
| 45 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 46 |
+
return tokenizer
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def read_jsonl_records(path: str, limit: int) -> list[dict[str, Any]]:
|
| 50 |
+
records = []
|
| 51 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 52 |
+
for line in handle:
|
| 53 |
+
line = line.strip()
|
| 54 |
+
if not line:
|
| 55 |
+
continue
|
| 56 |
+
records.append(json.loads(line))
|
| 57 |
+
if len(records) >= limit:
|
| 58 |
+
break
|
| 59 |
+
return records
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def build_batch(config, tokenizer, records: list[dict[str, Any]], device: torch.device) -> dict[str, torch.Tensor]:
|
| 63 |
+
input_rows = []
|
| 64 |
+
attention_rows = []
|
| 65 |
+
label_rows = []
|
| 66 |
+
train_tokens = []
|
| 67 |
+
|
| 68 |
+
for record in records:
|
| 69 |
+
turns, _ = parse_sft_record(record, config)
|
| 70 |
+
if not turns:
|
| 71 |
+
continue
|
| 72 |
+
input_ids, attention_mask, mask = build_sft_sequence_tokens(
|
| 73 |
+
turns=turns,
|
| 74 |
+
tokenizer=tokenizer,
|
| 75 |
+
user_token=getattr(config, "user_token", "<user>"),
|
| 76 |
+
assistant_token=getattr(config, "assistant_token", "<assistant>"),
|
| 77 |
+
max_seq_length=config.model.max_seq_length,
|
| 78 |
+
)
|
| 79 |
+
labels = build_response_only_next_token_labels(input_ids, mask)
|
| 80 |
+
input_rows.append(input_ids)
|
| 81 |
+
attention_rows.append(attention_mask)
|
| 82 |
+
label_rows.append(labels)
|
| 83 |
+
train_tokens.append(sum(1 for value in labels if value != -100))
|
| 84 |
+
|
| 85 |
+
if not input_rows:
|
| 86 |
+
raise ValueError("No valid SFT records found for the diagnostic batch")
|
| 87 |
+
|
| 88 |
+
return {
|
| 89 |
+
"input_ids": torch.tensor(input_rows, dtype=torch.long, device=device),
|
| 90 |
+
"attention_mask": torch.tensor(attention_rows, dtype=torch.long, device=device),
|
| 91 |
+
"labels": torch.tensor(label_rows, dtype=torch.long, device=device),
|
| 92 |
+
"train_tokens": torch.tensor(train_tokens, dtype=torch.long),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@torch.no_grad()
|
| 97 |
+
def score_batch(model, batch: dict[str, torch.Tensor], dtype: torch.dtype) -> float:
|
| 98 |
+
model.eval()
|
| 99 |
+
device_type = "cuda" if batch["input_ids"].is_cuda else "cpu"
|
| 100 |
+
enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)
|
| 101 |
+
with torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled):
|
| 102 |
+
outputs = model(
|
| 103 |
+
input_ids=batch["input_ids"],
|
| 104 |
+
attention_mask=batch["attention_mask"],
|
| 105 |
+
labels=batch["labels"],
|
| 106 |
+
)
|
| 107 |
+
return float(outputs["loss"].detach().cpu())
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def grad_l2_norm(parameters) -> float:
|
| 111 |
+
total = 0.0
|
| 112 |
+
for parameter in parameters:
|
| 113 |
+
if parameter.grad is None:
|
| 114 |
+
continue
|
| 115 |
+
grad = parameter.grad.detach()
|
| 116 |
+
total += float(torch.sum(grad.float() * grad.float()).cpu())
|
| 117 |
+
return math.sqrt(total)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def grad_summary(named_parameters, max_items: int = 12) -> dict[str, Any]:
|
| 121 |
+
groups: dict[str, dict[str, Any]] = {}
|
| 122 |
+
worst = []
|
| 123 |
+
nonfinite = []
|
| 124 |
+
|
| 125 |
+
for name, parameter in named_parameters:
|
| 126 |
+
if parameter.grad is None:
|
| 127 |
+
continue
|
| 128 |
+
grad = parameter.grad.detach().float()
|
| 129 |
+
finite = torch.isfinite(grad)
|
| 130 |
+
finite_count = int(finite.sum().cpu())
|
| 131 |
+
numel = grad.numel()
|
| 132 |
+
finite_abs_max = float(grad[finite].abs().max().cpu()) if finite_count else float("inf")
|
| 133 |
+
has_nonfinite = finite_count != numel
|
| 134 |
+
if has_nonfinite:
|
| 135 |
+
nonfinite.append(name)
|
| 136 |
+
|
| 137 |
+
if ".layers." in name:
|
| 138 |
+
parts = name.split(".")
|
| 139 |
+
try:
|
| 140 |
+
idx = parts.index("layers")
|
| 141 |
+
group = "layer_" + parts[idx + 1]
|
| 142 |
+
except (ValueError, IndexError):
|
| 143 |
+
group = "layers"
|
| 144 |
+
else:
|
| 145 |
+
group = name.split(".", 1)[0]
|
| 146 |
+
|
| 147 |
+
entry = groups.setdefault(group, {
|
| 148 |
+
"numel": 0,
|
| 149 |
+
"finite": 0,
|
| 150 |
+
"nonfinite_tensors": 0,
|
| 151 |
+
"max_abs_grad": 0.0,
|
| 152 |
+
})
|
| 153 |
+
entry["numel"] += numel
|
| 154 |
+
entry["finite"] += finite_count
|
| 155 |
+
entry["nonfinite_tensors"] += int(has_nonfinite)
|
| 156 |
+
entry["max_abs_grad"] = max(entry["max_abs_grad"], finite_abs_max)
|
| 157 |
+
worst.append((finite_abs_max, name))
|
| 158 |
+
|
| 159 |
+
worst.sort(reverse=True, key=lambda item: item[0])
|
| 160 |
+
return {
|
| 161 |
+
"groups": groups,
|
| 162 |
+
"worst_tensors": [{"name": name, "max_abs_grad": value} for value, name in worst[:max_items]],
|
| 163 |
+
"nonfinite_tensors": nonfinite[:max_items],
|
| 164 |
+
"nonfinite_tensor_count": len(nonfinite),
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def freeze_ssm_core_parameters(model) -> int:
|
| 169 |
+
frozen = 0
|
| 170 |
+
markers = (
|
| 171 |
+
".ssm_lanes.",
|
| 172 |
+
".ssm.",
|
| 173 |
+
)
|
| 174 |
+
for name, parameter in model.named_parameters():
|
| 175 |
+
if any(marker in name for marker in markers):
|
| 176 |
+
parameter.requires_grad_(False)
|
| 177 |
+
frozen += parameter.numel()
|
| 178 |
+
return frozen
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def main() -> None:
|
| 182 |
+
parser = argparse.ArgumentParser()
|
| 183 |
+
parser.add_argument("--config", required=True)
|
| 184 |
+
parser.add_argument("--checkpoint", required=True)
|
| 185 |
+
parser.add_argument("--output", required=True)
|
| 186 |
+
parser.add_argument("--samples", type=int, default=2)
|
| 187 |
+
parser.add_argument("--steps", type=int, default=80)
|
| 188 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
| 189 |
+
parser.add_argument("--log-every", type=int, default=10)
|
| 190 |
+
parser.add_argument("--device", default="cuda")
|
| 191 |
+
parser.add_argument("--dtype", choices=["config", "float32", "float16", "bfloat16"], default="config")
|
| 192 |
+
parser.add_argument("--no-clip", action="store_true")
|
| 193 |
+
parser.add_argument("--freeze-ssm-core", action="store_true")
|
| 194 |
+
parser.add_argument("--ssm-branch-rms-norm", action="store_true")
|
| 195 |
+
parser.add_argument("--ssm-branch-clip-value", type=float, default=None)
|
| 196 |
+
parser.add_argument("--block-residual-rms-norm", action="store_true")
|
| 197 |
+
parser.add_argument("--block-residual-rms-target", type=float, default=None)
|
| 198 |
+
parser.add_argument("--seed", type=int, default=123)
|
| 199 |
+
args = parser.parse_args()
|
| 200 |
+
|
| 201 |
+
set_seed(args.seed)
|
| 202 |
+
config = load_config(args.config, TrainingModeEnum.SFT)
|
| 203 |
+
if args.ssm_branch_rms_norm:
|
| 204 |
+
config.model.ssm_branch_rms_norm = True
|
| 205 |
+
if args.ssm_branch_clip_value is not None:
|
| 206 |
+
config.model.ssm_branch_clip_value = args.ssm_branch_clip_value
|
| 207 |
+
if args.block_residual_rms_norm:
|
| 208 |
+
config.model.block_residual_rms_norm = True
|
| 209 |
+
if args.block_residual_rms_target is not None:
|
| 210 |
+
config.model.block_residual_rms_target = args.block_residual_rms_target
|
| 211 |
+
device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
|
| 212 |
+
if args.dtype == "float32":
|
| 213 |
+
dtype = torch.float32
|
| 214 |
+
elif args.dtype == "float16":
|
| 215 |
+
dtype = torch.float16
|
| 216 |
+
elif args.dtype == "bfloat16":
|
| 217 |
+
dtype = torch.bfloat16
|
| 218 |
+
else:
|
| 219 |
+
dtype = torch.bfloat16 if str(config.dtype) == "DataTypeEnum.BFLOAT16" or str(config.dtype) == "bfloat16" else torch.float32
|
| 220 |
+
|
| 221 |
+
tokenizer = load_tokenizer(config.dataset.tokenizer_path)
|
| 222 |
+
records = read_jsonl_records(config.dataset.jsonl_path, args.samples)
|
| 223 |
+
batch = build_batch(config, tokenizer, records, device)
|
| 224 |
+
|
| 225 |
+
model = create_model(config, device)
|
| 226 |
+
checkpoint = CheckpointManager(config.checkpoint_dir).load(args.checkpoint, device=device)
|
| 227 |
+
model.load_state_dict(checkpoint["model_state"], strict=False)
|
| 228 |
+
frozen_params = freeze_ssm_core_parameters(model) if args.freeze_ssm_core else 0
|
| 229 |
+
|
| 230 |
+
initial_loss = score_batch(model, batch, dtype)
|
| 231 |
+
|
| 232 |
+
trainable_params = [parameter for parameter in model.parameters() if parameter.requires_grad]
|
| 233 |
+
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.0)
|
| 234 |
+
history = []
|
| 235 |
+
device_type = "cuda" if device.type == "cuda" else "cpu"
|
| 236 |
+
autocast_enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)
|
| 237 |
+
|
| 238 |
+
model.train()
|
| 239 |
+
for step in range(1, args.steps + 1):
|
| 240 |
+
optimizer.zero_grad(set_to_none=True)
|
| 241 |
+
with torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
|
| 242 |
+
outputs = model(
|
| 243 |
+
input_ids=batch["input_ids"],
|
| 244 |
+
attention_mask=batch["attention_mask"],
|
| 245 |
+
labels=batch["labels"],
|
| 246 |
+
)
|
| 247 |
+
loss = outputs["loss"]
|
| 248 |
+
loss.backward()
|
| 249 |
+
grad_norm = grad_l2_norm(trainable_params)
|
| 250 |
+
stats = None
|
| 251 |
+
if step == 1 or step % args.log_every == 0 or step == args.steps:
|
| 252 |
+
stats = grad_summary(model.named_parameters())
|
| 253 |
+
if not args.no_clip:
|
| 254 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
| 255 |
+
optimizer.step()
|
| 256 |
+
|
| 257 |
+
if step == 1 or step % args.log_every == 0 or step == args.steps:
|
| 258 |
+
item = {
|
| 259 |
+
"step": step,
|
| 260 |
+
"loss": float(loss.detach().cpu()),
|
| 261 |
+
"grad_l2_norm": grad_norm,
|
| 262 |
+
}
|
| 263 |
+
if stats is not None:
|
| 264 |
+
item["grad_summary"] = stats
|
| 265 |
+
history.append(item)
|
| 266 |
+
|
| 267 |
+
final_loss = score_batch(model, batch, dtype)
|
| 268 |
+
result = {
|
| 269 |
+
"checkpoint": str(Path(args.checkpoint)),
|
| 270 |
+
"config": str(Path(args.config)),
|
| 271 |
+
"dataset": config.dataset.jsonl_path,
|
| 272 |
+
"samples": len(records),
|
| 273 |
+
"sequence_length": config.model.max_seq_length,
|
| 274 |
+
"train_tokens_per_sample": batch["train_tokens"].tolist(),
|
| 275 |
+
"lr": args.lr,
|
| 276 |
+
"steps": args.steps,
|
| 277 |
+
"clip_grad_norm": not args.no_clip,
|
| 278 |
+
"freeze_ssm_core": args.freeze_ssm_core,
|
| 279 |
+
"ssm_branch_rms_norm": config.model.ssm_branch_rms_norm,
|
| 280 |
+
"ssm_branch_clip_value": config.model.ssm_branch_clip_value,
|
| 281 |
+
"block_residual_rms_norm": config.model.block_residual_rms_norm,
|
| 282 |
+
"block_residual_rms_target": config.model.block_residual_rms_target,
|
| 283 |
+
"frozen_params": frozen_params,
|
| 284 |
+
"trainable_params": sum(parameter.numel() for parameter in trainable_params),
|
| 285 |
+
"initial_loss": initial_loss,
|
| 286 |
+
"final_loss": final_loss,
|
| 287 |
+
"loss_delta": final_loss - initial_loss,
|
| 288 |
+
"history": history,
|
| 289 |
+
"device": str(device),
|
| 290 |
+
"dtype": str(dtype),
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
output = Path(args.output)
|
| 294 |
+
output.parent.mkdir(parents=True, exist_ok=True)
|
| 295 |
+
output.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
| 296 |
+
print(json.dumps(result, indent=2))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
main()
|
code/TaoTrain/scripts/profile_taonet_components.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Profile TaoNet and TaoNet-SSM component costs on synthetic token batches.
|
| 2 |
+
|
| 3 |
+
The real-token benchmark tells us end-to-end quality and throughput. This
|
| 4 |
+
script is the companion microscope: it times forward components such as the
|
| 5 |
+
SSM core, gates, projections, FFN, embeddings, and output head so hardware work
|
| 6 |
+
targets the largest measured costs.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
from contextlib import nullcontext
|
| 14 |
+
from contextlib import redirect_stdout
|
| 15 |
+
import io
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import platform
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
from typing import Any
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 27 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 28 |
+
if str(SRC_ROOT) not in sys.path:
|
| 29 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 30 |
+
|
| 31 |
+
from taoTrain.config import ModelConfig
|
| 32 |
+
from taoTrain.models import get_model
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
DTYPES = {
|
| 36 |
+
"float32": torch.float32,
|
| 37 |
+
"fp32": torch.float32,
|
| 38 |
+
"float16": torch.float16,
|
| 39 |
+
"fp16": torch.float16,
|
| 40 |
+
"bfloat16": torch.bfloat16,
|
| 41 |
+
"bf16": torch.bfloat16,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def synchronize(device: torch.device) -> None:
|
| 46 |
+
if device.type == "cuda":
|
| 47 |
+
torch.cuda.synchronize(device)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def reset_memory(device: torch.device) -> None:
|
| 51 |
+
if device.type == "cuda":
|
| 52 |
+
torch.cuda.reset_peak_memory_stats(device)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def memory_stats(device: torch.device) -> dict[str, float | None]:
|
| 56 |
+
if device.type != "cuda":
|
| 57 |
+
return {"peak_allocated_mb": None, "peak_reserved_mb": None}
|
| 58 |
+
return {
|
| 59 |
+
"peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
|
| 60 |
+
"peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ComponentTimer:
|
| 65 |
+
def __init__(self, device: torch.device) -> None:
|
| 66 |
+
self.device = device
|
| 67 |
+
self.records: dict[str, list[float]] = defaultdict(list)
|
| 68 |
+
self._starts: dict[int, Any] = {}
|
| 69 |
+
self._handles = []
|
| 70 |
+
|
| 71 |
+
def _record_ms(self, name: str, start: Any) -> None:
|
| 72 |
+
if self.device.type == "cuda":
|
| 73 |
+
end = torch.cuda.Event(enable_timing=True)
|
| 74 |
+
end.record()
|
| 75 |
+
end.synchronize()
|
| 76 |
+
self.records[name].append(float(start.elapsed_time(end)))
|
| 77 |
+
else:
|
| 78 |
+
self.records[name].append((time.perf_counter() - start) * 1000.0)
|
| 79 |
+
|
| 80 |
+
def add(self, module: torch.nn.Module, name: str) -> None:
|
| 81 |
+
def pre_hook(mod, inputs):
|
| 82 |
+
del inputs
|
| 83 |
+
if self.device.type == "cuda":
|
| 84 |
+
start = torch.cuda.Event(enable_timing=True)
|
| 85 |
+
start.record()
|
| 86 |
+
else:
|
| 87 |
+
start = time.perf_counter()
|
| 88 |
+
self._starts[id(mod)] = start
|
| 89 |
+
|
| 90 |
+
def post_hook(mod, inputs, output):
|
| 91 |
+
del inputs, output
|
| 92 |
+
start = self._starts.pop(id(mod), None)
|
| 93 |
+
if start is not None:
|
| 94 |
+
self._record_ms(name, start)
|
| 95 |
+
|
| 96 |
+
self._handles.append(module.register_forward_pre_hook(pre_hook))
|
| 97 |
+
self._handles.append(module.register_forward_hook(post_hook))
|
| 98 |
+
|
| 99 |
+
def close(self) -> None:
|
| 100 |
+
for handle in self._handles:
|
| 101 |
+
handle.remove()
|
| 102 |
+
self._handles.clear()
|
| 103 |
+
|
| 104 |
+
def summary(self) -> list[dict[str, float | str | int]]:
|
| 105 |
+
rows = []
|
| 106 |
+
for name, values in sorted(self.records.items()):
|
| 107 |
+
if not values:
|
| 108 |
+
continue
|
| 109 |
+
rows.append(
|
| 110 |
+
{
|
| 111 |
+
"component": name,
|
| 112 |
+
"calls": len(values),
|
| 113 |
+
"mean_ms": sum(values) / len(values),
|
| 114 |
+
"total_ms": sum(values),
|
| 115 |
+
"min_ms": min(values),
|
| 116 |
+
"max_ms": max(values),
|
| 117 |
+
}
|
| 118 |
+
)
|
| 119 |
+
rows.sort(key=lambda row: float(row["total_ms"]), reverse=True)
|
| 120 |
+
return rows
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
|
| 124 |
+
d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75)
|
| 125 |
+
d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads
|
| 126 |
+
hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4
|
| 127 |
+
return ModelConfig(
|
| 128 |
+
architecture_type=architecture,
|
| 129 |
+
vocab_size=args.vocab_size,
|
| 130 |
+
hidden_dim=args.hidden_dim,
|
| 131 |
+
num_layers=args.num_layers,
|
| 132 |
+
num_heads=args.num_heads,
|
| 133 |
+
max_seq_length=args.seq_len,
|
| 134 |
+
d_latent_kv=d_latent_kv,
|
| 135 |
+
d_rope=d_rope,
|
| 136 |
+
hidden_dim_ff=hidden_dim_ff,
|
| 137 |
+
dropout=args.dropout,
|
| 138 |
+
gqa_groups=args.gqa_groups,
|
| 139 |
+
rope_scale=args.rope_scale,
|
| 140 |
+
yarn_alpha=args.yarn_alpha,
|
| 141 |
+
init_std=args.init_std,
|
| 142 |
+
ssm_core=args.ssm_core,
|
| 143 |
+
ssm_hidden_dim=args.ssm_hidden_dim,
|
| 144 |
+
ssm_mixer_dim=args.ssm_mixer_dim,
|
| 145 |
+
ssm_rank=args.ssm_rank,
|
| 146 |
+
ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
|
| 147 |
+
ssm_kernel_mode=args.ssm_kernel_mode,
|
| 148 |
+
ssm_kernel_threshold=args.ssm_kernel_threshold,
|
| 149 |
+
ssm_dt_min=args.ssm_dt_min,
|
| 150 |
+
ssm_dt_max=args.ssm_dt_max,
|
| 151 |
+
ssm_dt_init=args.ssm_dt_init,
|
| 152 |
+
ssm_use_padding_mask=False,
|
| 153 |
+
ssm_activation=args.ssm_activation,
|
| 154 |
+
ssm_gate=args.ssm_gate,
|
| 155 |
+
ssm_input_gate=args.ssm_input_gate,
|
| 156 |
+
ssm_layer_scale_init=args.ssm_layer_scale_init,
|
| 157 |
+
ssm_local_shift=args.ssm_local_shift,
|
| 158 |
+
ssm_local_shift_init=args.ssm_local_shift_init,
|
| 159 |
+
ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def add_component_hooks(model: torch.nn.Module, architecture: str, timer: ComponentTimer) -> None:
|
| 164 |
+
timer.add(model.token_embedding, "embedding")
|
| 165 |
+
timer.add(model.final_norm, "final_norm")
|
| 166 |
+
timer.add(model.output_head, "output_head")
|
| 167 |
+
for layer_index, block in enumerate(model.blocks):
|
| 168 |
+
prefix = f"block{layer_index}"
|
| 169 |
+
if architecture == "taonet_ssm":
|
| 170 |
+
mixer = block.mixer
|
| 171 |
+
timer.add(mixer.norm, f"{prefix}.mixer.norm")
|
| 172 |
+
if mixer.input_gate is not None:
|
| 173 |
+
timer.add(mixer.input_gate, f"{prefix}.mixer.input_gate")
|
| 174 |
+
timer.add(mixer.input_proj, f"{prefix}.mixer.input_proj")
|
| 175 |
+
timer.add(mixer.ssm, f"{prefix}.mixer.ssm_core")
|
| 176 |
+
timer.add(mixer.activation, f"{prefix}.mixer.activation")
|
| 177 |
+
timer.add(mixer.out_proj, f"{prefix}.mixer.out_proj")
|
| 178 |
+
if mixer.output_gate is not None:
|
| 179 |
+
timer.add(mixer.output_gate, f"{prefix}.mixer.output_gate")
|
| 180 |
+
timer.add(mixer.proj_dropout, f"{prefix}.mixer.dropout")
|
| 181 |
+
else:
|
| 182 |
+
mla = block.mla
|
| 183 |
+
timer.add(mla.norm, f"{prefix}.attention.norm")
|
| 184 |
+
timer.add(mla.q_proj, f"{prefix}.attention.q_proj")
|
| 185 |
+
timer.add(mla.k_proj, f"{prefix}.attention.k_proj")
|
| 186 |
+
timer.add(mla.v_proj, f"{prefix}.attention.v_proj")
|
| 187 |
+
timer.add(mla.out_proj, f"{prefix}.attention.out_proj")
|
| 188 |
+
timer.add(mla.attn_dropout, f"{prefix}.attention.attn_dropout")
|
| 189 |
+
timer.add(mla.proj_dropout, f"{prefix}.attention.proj_dropout")
|
| 190 |
+
timer.add(block.ff_norm, f"{prefix}.ff.norm")
|
| 191 |
+
timer.add(block.ff_gate, f"{prefix}.ff.gate")
|
| 192 |
+
timer.add(block.ff_value, f"{prefix}.ff.value")
|
| 193 |
+
timer.add(block.ff_out, f"{prefix}.ff.out")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]:
|
| 197 |
+
for _ in range(warmup):
|
| 198 |
+
fn()
|
| 199 |
+
synchronize(device)
|
| 200 |
+
|
| 201 |
+
latencies = []
|
| 202 |
+
for _ in range(repeats):
|
| 203 |
+
reset_memory(device)
|
| 204 |
+
synchronize(device)
|
| 205 |
+
start = time.perf_counter()
|
| 206 |
+
fn()
|
| 207 |
+
synchronize(device)
|
| 208 |
+
latencies.append(time.perf_counter() - start)
|
| 209 |
+
mean_s = sum(latencies) / len(latencies)
|
| 210 |
+
return {
|
| 211 |
+
"mean_ms": mean_s * 1000.0,
|
| 212 |
+
"min_ms": min(latencies) * 1000.0,
|
| 213 |
+
"max_ms": max(latencies) * 1000.0,
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def profile_architecture(
|
| 218 |
+
args: argparse.Namespace,
|
| 219 |
+
*,
|
| 220 |
+
architecture: str,
|
| 221 |
+
device: torch.device,
|
| 222 |
+
dtype: torch.dtype,
|
| 223 |
+
) -> dict[str, Any]:
|
| 224 |
+
torch.manual_seed(args.seed)
|
| 225 |
+
if device.type == "cuda":
|
| 226 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 227 |
+
|
| 228 |
+
config = build_config(args, architecture)
|
| 229 |
+
with redirect_stdout(io.StringIO()):
|
| 230 |
+
model = get_model(config, device=device)
|
| 231 |
+
model.train()
|
| 232 |
+
|
| 233 |
+
input_ids = torch.randint(
|
| 234 |
+
low=0,
|
| 235 |
+
high=args.vocab_size,
|
| 236 |
+
size=(args.batch_size, args.seq_len),
|
| 237 |
+
device=device,
|
| 238 |
+
)
|
| 239 |
+
labels = torch.randint(
|
| 240 |
+
low=0,
|
| 241 |
+
high=args.vocab_size,
|
| 242 |
+
size=(args.batch_size, args.seq_len),
|
| 243 |
+
device=device,
|
| 244 |
+
)
|
| 245 |
+
attention_mask = torch.ones_like(input_ids)
|
| 246 |
+
|
| 247 |
+
autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 248 |
+
|
| 249 |
+
def autocast_context():
|
| 250 |
+
if not autocast_enabled:
|
| 251 |
+
return nullcontext()
|
| 252 |
+
return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)
|
| 253 |
+
|
| 254 |
+
def forward_only() -> torch.Tensor:
|
| 255 |
+
with torch.no_grad():
|
| 256 |
+
with autocast_context():
|
| 257 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 258 |
+
return outputs["loss"]
|
| 259 |
+
|
| 260 |
+
def forward_backward() -> torch.Tensor:
|
| 261 |
+
model.zero_grad(set_to_none=True)
|
| 262 |
+
with autocast_context():
|
| 263 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 264 |
+
loss = outputs["loss"]
|
| 265 |
+
loss.backward()
|
| 266 |
+
return loss
|
| 267 |
+
|
| 268 |
+
no_timer_forward = time_repeats(
|
| 269 |
+
forward_only,
|
| 270 |
+
device=device,
|
| 271 |
+
warmup=args.warmup,
|
| 272 |
+
repeats=args.repeats,
|
| 273 |
+
)
|
| 274 |
+
no_timer_backward = time_repeats(
|
| 275 |
+
forward_backward,
|
| 276 |
+
device=device,
|
| 277 |
+
warmup=args.warmup,
|
| 278 |
+
repeats=args.repeats,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
timer = ComponentTimer(device)
|
| 282 |
+
add_component_hooks(model, architecture, timer)
|
| 283 |
+
try:
|
| 284 |
+
for _ in range(args.component_warmup):
|
| 285 |
+
forward_only()
|
| 286 |
+
synchronize(device)
|
| 287 |
+
for _ in range(args.component_repeats):
|
| 288 |
+
forward_only()
|
| 289 |
+
synchronize(device)
|
| 290 |
+
finally:
|
| 291 |
+
timer.close()
|
| 292 |
+
|
| 293 |
+
tokens = args.batch_size * args.seq_len
|
| 294 |
+
component_rows = timer.summary()
|
| 295 |
+
return {
|
| 296 |
+
"architecture": architecture,
|
| 297 |
+
"total_params": sum(param.numel() for param in model.parameters()),
|
| 298 |
+
"trainable_params": sum(param.numel() for param in model.parameters() if param.requires_grad),
|
| 299 |
+
"forward": {
|
| 300 |
+
**no_timer_forward,
|
| 301 |
+
"tokens_per_s": tokens / max(no_timer_forward["mean_ms"] / 1000.0, 1e-12),
|
| 302 |
+
},
|
| 303 |
+
"forward_backward": {
|
| 304 |
+
**no_timer_backward,
|
| 305 |
+
"tokens_per_s": tokens / max(no_timer_backward["mean_ms"] / 1000.0, 1e-12),
|
| 306 |
+
**memory_stats(device),
|
| 307 |
+
},
|
| 308 |
+
"components_forward": component_rows,
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def main() -> int:
|
| 313 |
+
parser = argparse.ArgumentParser(description=__doc__)
|
| 314 |
+
parser.add_argument("--architectures", default="taonet,taonet_ssm")
|
| 315 |
+
parser.add_argument("--vocab-size", type=int, default=8192)
|
| 316 |
+
parser.add_argument("--batch-size", type=int, default=32)
|
| 317 |
+
parser.add_argument("--seq-len", type=int, default=512)
|
| 318 |
+
parser.add_argument("--hidden-dim", type=int, default=256)
|
| 319 |
+
parser.add_argument("--num-layers", type=int, default=4)
|
| 320 |
+
parser.add_argument("--num-heads", type=int, default=4)
|
| 321 |
+
parser.add_argument("--d-latent-kv", type=int, default=None)
|
| 322 |
+
parser.add_argument("--d-rope", type=int, default=None)
|
| 323 |
+
parser.add_argument("--hidden-dim-ff", type=int, default=None)
|
| 324 |
+
parser.add_argument("--dropout", type=float, default=0.0)
|
| 325 |
+
parser.add_argument("--gqa-groups", type=int, default=1)
|
| 326 |
+
parser.add_argument("--rope-scale", type=float, default=40.0)
|
| 327 |
+
parser.add_argument("--yarn-alpha", type=float, default=1.0)
|
| 328 |
+
parser.add_argument("--init-std", type=float, default=0.02)
|
| 329 |
+
parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
|
| 330 |
+
parser.add_argument("--ssm-hidden-dim", type=int, default=16)
|
| 331 |
+
parser.add_argument("--ssm-mixer-dim", type=int, default=128)
|
| 332 |
+
parser.add_argument("--ssm-rank", type=int, default=1)
|
| 333 |
+
parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
|
| 334 |
+
parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
|
| 335 |
+
parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
|
| 336 |
+
parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
|
| 337 |
+
parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
|
| 338 |
+
parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
|
| 339 |
+
parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
|
| 340 |
+
parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
|
| 341 |
+
parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
|
| 342 |
+
parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
|
| 343 |
+
parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True)
|
| 344 |
+
parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
|
| 345 |
+
parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=True)
|
| 346 |
+
parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
|
| 347 |
+
parser.add_argument("--device", default="auto")
|
| 348 |
+
parser.add_argument("--warmup", type=int, default=2)
|
| 349 |
+
parser.add_argument("--repeats", type=int, default=5)
|
| 350 |
+
parser.add_argument("--component-warmup", type=int, default=1)
|
| 351 |
+
parser.add_argument("--component-repeats", type=int, default=3)
|
| 352 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 353 |
+
parser.add_argument("--output", type=Path, default=None)
|
| 354 |
+
args = parser.parse_args()
|
| 355 |
+
|
| 356 |
+
if args.device == "auto":
|
| 357 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 358 |
+
else:
|
| 359 |
+
device = torch.device(args.device)
|
| 360 |
+
dtype = DTYPES[args.dtype]
|
| 361 |
+
if device.type == "cuda":
|
| 362 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 363 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 364 |
+
|
| 365 |
+
results = [
|
| 366 |
+
profile_architecture(args, architecture=architecture.strip(), device=device, dtype=dtype)
|
| 367 |
+
for architecture in args.architectures.split(",")
|
| 368 |
+
if architecture.strip()
|
| 369 |
+
]
|
| 370 |
+
report = {
|
| 371 |
+
"metadata": {
|
| 372 |
+
"python": platform.python_version(),
|
| 373 |
+
"platform": platform.platform(),
|
| 374 |
+
"torch": torch.__version__,
|
| 375 |
+
"cuda_available": torch.cuda.is_available(),
|
| 376 |
+
"cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
|
| 377 |
+
"device": str(device),
|
| 378 |
+
"dtype": str(dtype).replace("torch.", ""),
|
| 379 |
+
"args": vars(args) | {"output": str(args.output) if args.output else None},
|
| 380 |
+
},
|
| 381 |
+
"results": results,
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
text = json.dumps(report, indent=2, sort_keys=True, default=str)
|
| 385 |
+
print(text)
|
| 386 |
+
if args.output is not None:
|
| 387 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 388 |
+
args.output.write_text(text, encoding="utf-8")
|
| 389 |
+
return 0
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
if __name__ == "__main__":
|
| 393 |
+
raise SystemExit(main())
|
code/TaoTrain/scripts/remote/job_status.sh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-}"
|
| 5 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 6 |
+
|
| 7 |
+
if [[ -z "$RUN_ID" ]]; then
|
| 8 |
+
echo "RUN_ID is required" >&2
|
| 9 |
+
exit 2
|
| 10 |
+
fi
|
| 11 |
+
|
| 12 |
+
safe_run_id="$(printf '%s' "$RUN_ID" | tr -c 'A-Za-z0-9_.-' '_')"
|
| 13 |
+
job_dir="${JOB_ROOT%/}/${safe_run_id}"
|
| 14 |
+
|
| 15 |
+
if [[ ! -d "$job_dir" ]]; then
|
| 16 |
+
echo "Job directory not found: $job_dir" >&2
|
| 17 |
+
exit 1
|
| 18 |
+
fi
|
| 19 |
+
|
| 20 |
+
echo "== status.json =="
|
| 21 |
+
cat "$job_dir/status.json" 2>/dev/null || true
|
| 22 |
+
echo
|
| 23 |
+
echo "== markers =="
|
| 24 |
+
ls -1 "$job_dir"/DONE "$job_dir"/FAILED 2>/dev/null || true
|
| 25 |
+
echo
|
| 26 |
+
echo "== tmux =="
|
| 27 |
+
tmux ls 2>/dev/null | grep -F "taotern_${safe_run_id}" || true
|
| 28 |
+
echo
|
| 29 |
+
echo "== recent log =="
|
| 30 |
+
tail -n "${TAIL_LINES:-80}" "$job_dir/train.log" 2>/dev/null || true
|
| 31 |
+
echo
|
| 32 |
+
echo "== outputs =="
|
| 33 |
+
find "$job_dir/outputs" -maxdepth 2 -type f 2>/dev/null | sort | tail -n 40 || true
|
| 34 |
+
echo
|
| 35 |
+
echo "== checkpoints =="
|
| 36 |
+
find "$job_dir/checkpoints" -maxdepth 1 -type f 2>/dev/null | sort | tail -n 20 || true
|
code/TaoTrain/scripts/remote/run_200m_base_suite.sh
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 5 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 6 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 8 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 9 |
+
OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/200m-base-suite}"
|
| 10 |
+
CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
|
| 11 |
+
|
| 12 |
+
# Stage-1 defaults are intentionally modest. Increase these through environment
|
| 13 |
+
# variables after the 200M shapes are stable on the RTX5090.
|
| 14 |
+
MAX_TOKENS="${MAX_TOKENS:-50000000}"
|
| 15 |
+
MAX_RECORDS="${MAX_RECORDS:-100000}"
|
| 16 |
+
TRAIN_STEPS="${TRAIN_STEPS:-200}"
|
| 17 |
+
EVAL_BATCHES="${EVAL_BATCHES:-16}"
|
| 18 |
+
BATCH_SIZES="${BATCH_SIZES:-4,8}"
|
| 19 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 20 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0006}"
|
| 21 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 22 |
+
DRY_RUN="${DRY_RUN:-0}"
|
| 23 |
+
|
| 24 |
+
export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
|
| 25 |
+
mkdir -p "$OUTPUT_BASE" "$CHECKPOINT_BASE"
|
| 26 |
+
|
| 27 |
+
run_variant() {
|
| 28 |
+
local variant="$1"
|
| 29 |
+
shift
|
| 30 |
+
local output_dir="$OUTPUT_BASE/$variant"
|
| 31 |
+
local checkpoint_dir="$CHECKPOINT_BASE/$variant"
|
| 32 |
+
mkdir -p "$output_dir" "$checkpoint_dir"
|
| 33 |
+
|
| 34 |
+
local cmd="$PYTHON_BIN scripts/benchmark_taonet_real_tokens.py \
|
| 35 |
+
--data-path $DATA_PATH \
|
| 36 |
+
--text-field text \
|
| 37 |
+
--tokenizer-type sentencepiece \
|
| 38 |
+
--tokenizer-path $TOKENIZER_PATH \
|
| 39 |
+
--max-records $MAX_RECORDS \
|
| 40 |
+
--max-tokens $MAX_TOKENS \
|
| 41 |
+
--eval-fraction 0.1 \
|
| 42 |
+
--batch-sizes $BATCH_SIZES \
|
| 43 |
+
--seq-len $SEQ_LEN \
|
| 44 |
+
--dtype bf16 \
|
| 45 |
+
--device cuda \
|
| 46 |
+
--warmup 1 \
|
| 47 |
+
--repeats 2 \
|
| 48 |
+
--backward \
|
| 49 |
+
--train-steps $TRAIN_STEPS \
|
| 50 |
+
--learning-rate $LEARNING_RATE \
|
| 51 |
+
--weight-decay $WEIGHT_DECAY \
|
| 52 |
+
--eval-batches $EVAL_BATCHES \
|
| 53 |
+
--output-dir $output_dir \
|
| 54 |
+
--resume-completed \
|
| 55 |
+
--incremental-output \
|
| 56 |
+
--save-case-checkpoints \
|
| 57 |
+
--checkpoint-dir $checkpoint_dir \
|
| 58 |
+
$*"
|
| 59 |
+
|
| 60 |
+
printf '\n=== 200M variant: %s ===\n' "$variant"
|
| 61 |
+
printf '%s\n' "$cmd"
|
| 62 |
+
if [ "$DRY_RUN" = "1" ]; then
|
| 63 |
+
return 0
|
| 64 |
+
fi
|
| 65 |
+
eval "$cmd"
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
run_variant attention_196m \
|
| 69 |
+
--architectures taonet \
|
| 70 |
+
--hidden-dim 960 \
|
| 71 |
+
--num-layers 16 \
|
| 72 |
+
--num-heads 8 \
|
| 73 |
+
--d-latent-kv 720 \
|
| 74 |
+
--d-rope 120 \
|
| 75 |
+
--hidden-dim-ff 2880
|
| 76 |
+
|
| 77 |
+
run_variant pure_ssm_196m_hadamard \
|
| 78 |
+
--architectures taonet_ssm \
|
| 79 |
+
--hidden-dim 1024 \
|
| 80 |
+
--num-layers 18 \
|
| 81 |
+
--num-heads 8 \
|
| 82 |
+
--d-latent-kv 768 \
|
| 83 |
+
--d-rope 128 \
|
| 84 |
+
--hidden-dim-ff 3072 \
|
| 85 |
+
--ssm-core dplr \
|
| 86 |
+
--ssm-hidden-dims 16 \
|
| 87 |
+
--ssm-mixer-dims 256 \
|
| 88 |
+
--ssm-num-lanes-list 2 \
|
| 89 |
+
--ssm-lane-combine channel \
|
| 90 |
+
--ssm-lane-modes split \
|
| 91 |
+
--ssm-split-mixes hadamard \
|
| 92 |
+
--ssm-rank 1 \
|
| 93 |
+
--ssm-kernel-mode conv \
|
| 94 |
+
--no-ssm-finite-tail-correction \
|
| 95 |
+
--ssm-gate-types channel \
|
| 96 |
+
--ssm-local-shift \
|
| 97 |
+
--ssm-local-shift-per-channel \
|
| 98 |
+
--ssm-local-shift-init 0.1
|
| 99 |
+
|
| 100 |
+
run_variant pure_ssm_196m_nomix \
|
| 101 |
+
--architectures taonet_ssm \
|
| 102 |
+
--hidden-dim 1024 \
|
| 103 |
+
--num-layers 18 \
|
| 104 |
+
--num-heads 8 \
|
| 105 |
+
--d-latent-kv 768 \
|
| 106 |
+
--d-rope 128 \
|
| 107 |
+
--hidden-dim-ff 3072 \
|
| 108 |
+
--ssm-core dplr \
|
| 109 |
+
--ssm-hidden-dims 16 \
|
| 110 |
+
--ssm-mixer-dims 256 \
|
| 111 |
+
--ssm-num-lanes-list 2 \
|
| 112 |
+
--ssm-lane-combine channel \
|
| 113 |
+
--ssm-lane-modes split \
|
| 114 |
+
--ssm-split-mixes none \
|
| 115 |
+
--ssm-rank 1 \
|
| 116 |
+
--ssm-kernel-mode conv \
|
| 117 |
+
--no-ssm-finite-tail-correction \
|
| 118 |
+
--ssm-gate-types channel \
|
| 119 |
+
--ssm-local-shift \
|
| 120 |
+
--ssm-local-shift-per-channel \
|
| 121 |
+
--ssm-local-shift-init 0.1
|
| 122 |
+
|
| 123 |
+
run_variant hybrid_ssm_first_199m \
|
| 124 |
+
--architectures taonet_hybrid \
|
| 125 |
+
--hidden-dim 1024 \
|
| 126 |
+
--num-layers 16 \
|
| 127 |
+
--num-heads 8 \
|
| 128 |
+
--d-latent-kv 768 \
|
| 129 |
+
--d-rope 128 \
|
| 130 |
+
--hidden-dim-ff 3072 \
|
| 131 |
+
--ssm-core dplr \
|
| 132 |
+
--ssm-hidden-dims 32 \
|
| 133 |
+
--ssm-mixer-dims 256 \
|
| 134 |
+
--ssm-num-lanes-list 2 \
|
| 135 |
+
--ssm-lane-combine channel \
|
| 136 |
+
--ssm-lane-modes split \
|
| 137 |
+
--ssm-split-mixes hadamard \
|
| 138 |
+
--ssm-rank 1 \
|
| 139 |
+
--ssm-kernel-mode conv \
|
| 140 |
+
--no-ssm-finite-tail-correction \
|
| 141 |
+
--ssm-gate-types channel \
|
| 142 |
+
--hybrid-patterns ssm_first \
|
| 143 |
+
--ssm-local-shift \
|
| 144 |
+
--ssm-local-shift-per-channel \
|
| 145 |
+
--ssm-local-shift-init 0.1
|
| 146 |
+
|
| 147 |
+
if [ "$DRY_RUN" != "1" ]; then
|
| 148 |
+
"$PYTHON_BIN" scripts/summarize_taonet_benchmark_suite.py --suite-dir "$OUTPUT_BASE"
|
| 149 |
+
fi
|
code/TaoTrain/scripts/remote/run_200m_branch_only_chat.sh
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 5 |
+
SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
|
| 6 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 7 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 8 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 9 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 10 |
+
OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/200m-branch-only-chat}"
|
| 11 |
+
CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
|
| 12 |
+
|
| 13 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 14 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 15 |
+
PRETRAIN_TOKENS="${PRETRAIN_TOKENS:-4000000000}"
|
| 16 |
+
SFT_STEPS="${SFT_STEPS:-50000}"
|
| 17 |
+
PRETRAIN_LR="${PRETRAIN_LR:-0.0008}"
|
| 18 |
+
SFT_LR="${SFT_LR:-0.00005}"
|
| 19 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 20 |
+
LOG_EVERY="${LOG_EVERY:-100}"
|
| 21 |
+
SAVE_EVERY="${SAVE_EVERY:-100000}"
|
| 22 |
+
SFT_SAVE_EVERY="${SFT_SAVE_EVERY:-10000}"
|
| 23 |
+
TOKENIZER_THREADS="${TOKENIZER_THREADS:-8}"
|
| 24 |
+
SAMPLES_PER_CHUNK="${SAMPLES_PER_CHUNK:-2000}"
|
| 25 |
+
BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
|
| 26 |
+
|
| 27 |
+
ceil_div() {
|
| 28 |
+
local numerator="$1"
|
| 29 |
+
local denominator="$2"
|
| 30 |
+
echo $(( (numerator + denominator - 1) / denominator ))
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
PRETRAIN_STEPS="${PRETRAIN_STEPS:-$(ceil_div "$PRETRAIN_TOKENS" $((BATCH_SIZE * SEQ_LEN)))}"
|
| 34 |
+
|
| 35 |
+
export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
|
| 36 |
+
mkdir -p "$OUTPUT_BASE/configs" "$OUTPUT_BASE/diagnostics" "$CHECKPOINT_BASE/pretrain" "$CHECKPOINT_BASE/sft"
|
| 37 |
+
|
| 38 |
+
cap_yaml=""
|
| 39 |
+
if [[ -n "$BLOCK_RESIDUAL_RMS_CAP" ]]; then
|
| 40 |
+
cap_yaml=" block_residual_rms_cap: $BLOCK_RESIDUAL_RMS_CAP"
|
| 41 |
+
fi
|
| 42 |
+
|
| 43 |
+
cat > "$OUTPUT_BASE/run_plan.json" <<JSON
|
| 44 |
+
{
|
| 45 |
+
"purpose": "200m_branch_only_pure_ssm_4b_pretrain_plus_sft_chatbot_attempt",
|
| 46 |
+
"candidate": "pure_ssm_196m_branch_rms_only",
|
| 47 |
+
"pretrain_token_positions": $PRETRAIN_TOKENS,
|
| 48 |
+
"pretrain_steps": $PRETRAIN_STEPS,
|
| 49 |
+
"batch_size": $BATCH_SIZE,
|
| 50 |
+
"seq_len": $SEQ_LEN,
|
| 51 |
+
"sft_steps": $SFT_STEPS,
|
| 52 |
+
"pretrain_lr": $PRETRAIN_LR,
|
| 53 |
+
"sft_lr": $SFT_LR,
|
| 54 |
+
"save_every_steps": $SAVE_EVERY,
|
| 55 |
+
"block_residual_rms_cap": "${BLOCK_RESIDUAL_RMS_CAP:-null}",
|
| 56 |
+
"notes": [
|
| 57 |
+
"Uses streaming JSONL training through TaoTrain CLI.",
|
| 58 |
+
"Counts token positions as batch_size * seq_len * max_steps.",
|
| 59 |
+
"Keeps SSM branch RMS normalization enabled and block residual RMS normalization disabled.",
|
| 60 |
+
"Runs corrected response-only SFT after pretraining."
|
| 61 |
+
]
|
| 62 |
+
}
|
| 63 |
+
JSON
|
| 64 |
+
|
| 65 |
+
PRETRAIN_CONFIG="$OUTPUT_BASE/configs/pretrain.yaml"
|
| 66 |
+
cat > "$PRETRAIN_CONFIG" <<YAML
|
| 67 |
+
model:
|
| 68 |
+
architecture_type: taonet_ssm
|
| 69 |
+
vocab_size: 8192
|
| 70 |
+
hidden_dim: 1024
|
| 71 |
+
num_layers: 18
|
| 72 |
+
num_heads: 8
|
| 73 |
+
max_seq_length: $SEQ_LEN
|
| 74 |
+
d_latent_kv: 768
|
| 75 |
+
d_rope: 128
|
| 76 |
+
hidden_dim_ff: 3072
|
| 77 |
+
dropout: 0.0
|
| 78 |
+
gqa_groups: 1
|
| 79 |
+
use_factorized_embedding: false
|
| 80 |
+
d_embed_rank: 96
|
| 81 |
+
init_std: 0.02
|
| 82 |
+
ssm_core: dplr
|
| 83 |
+
ssm_hidden_dim: 32
|
| 84 |
+
ssm_mixer_dim: 256
|
| 85 |
+
ssm_num_lanes: 2
|
| 86 |
+
ssm_lane_combine: channel
|
| 87 |
+
ssm_lane_mode: split
|
| 88 |
+
ssm_split_mix: none
|
| 89 |
+
ssm_rank: 1
|
| 90 |
+
ssm_max_low_rank_scale: 0.1
|
| 91 |
+
ssm_finite_tail_correction: false
|
| 92 |
+
ssm_discretization: bilinear
|
| 93 |
+
ssm_kernel_mode: conv
|
| 94 |
+
ssm_kernel_threshold: 64
|
| 95 |
+
ssm_dt_min: 1e-3
|
| 96 |
+
ssm_dt_max: 1e-1
|
| 97 |
+
ssm_dt_init: 1e-2
|
| 98 |
+
ssm_use_d: true
|
| 99 |
+
ssm_activation: gelu
|
| 100 |
+
ssm_gate: true
|
| 101 |
+
ssm_input_gate: true
|
| 102 |
+
ssm_gate_type: channel
|
| 103 |
+
ssm_use_padding_mask: false
|
| 104 |
+
ssm_layer_scale_init: 0.1
|
| 105 |
+
ssm_branch_rms_norm: true
|
| 106 |
+
block_residual_rms_norm: false
|
| 107 |
+
$cap_yaml
|
| 108 |
+
ssm_local_shift: true
|
| 109 |
+
ssm_local_shift_init: 0.1
|
| 110 |
+
ssm_local_shift_per_channel: true
|
| 111 |
+
|
| 112 |
+
dataset:
|
| 113 |
+
local: true
|
| 114 |
+
jsonl_path: $DATA_PATH
|
| 115 |
+
text_field: text
|
| 116 |
+
tokenizer_type: sentencepiece
|
| 117 |
+
tokenizer_path: $TOKENIZER_PATH
|
| 118 |
+
samples_per_chunk: $SAMPLES_PER_CHUNK
|
| 119 |
+
tokenizer_threads: $TOKENIZER_THREADS
|
| 120 |
+
|
| 121 |
+
sequence_length: $SEQ_LEN
|
| 122 |
+
batch_size: $BATCH_SIZE
|
| 123 |
+
num_epochs: 100000
|
| 124 |
+
max_steps: $PRETRAIN_STEPS
|
| 125 |
+
gradient_accumulation_steps: 1
|
| 126 |
+
max_grad_norm: 1.0
|
| 127 |
+
|
| 128 |
+
optimizer:
|
| 129 |
+
optimizer_type: adamw
|
| 130 |
+
learning_rate: $PRETRAIN_LR
|
| 131 |
+
weight_decay: $WEIGHT_DECAY
|
| 132 |
+
betas: [0.9, 0.999]
|
| 133 |
+
eps: 1e-8
|
| 134 |
+
|
| 135 |
+
scheduler:
|
| 136 |
+
scheduler_type: constant
|
| 137 |
+
warmup_steps: 1000
|
| 138 |
+
warmup_ratio: 0.0
|
| 139 |
+
|
| 140 |
+
dtype: bfloat16
|
| 141 |
+
device: cuda
|
| 142 |
+
checkpoint_dir: $CHECKPOINT_BASE/pretrain
|
| 143 |
+
save_every_steps: $SAVE_EVERY
|
| 144 |
+
save_best_model: false
|
| 145 |
+
keep_last_n_checkpoints: 3
|
| 146 |
+
eval_every_steps: $SAVE_EVERY
|
| 147 |
+
eval_samples: 32
|
| 148 |
+
log_every_steps: $LOG_EVERY
|
| 149 |
+
aim_repo: $OUTPUT_BASE/.aim-pretrain
|
| 150 |
+
seed: 43
|
| 151 |
+
num_workers: 0
|
| 152 |
+
pin_memory: true
|
| 153 |
+
YAML
|
| 154 |
+
|
| 155 |
+
printf '\n============================================================\n'
|
| 156 |
+
printf '200M branch-only pure SSM pretrain\n'
|
| 157 |
+
printf 'token_positions=%s batch=%s seq_len=%s steps=%s save_every=%s\n' \
|
| 158 |
+
"$PRETRAIN_TOKENS" "$BATCH_SIZE" "$SEQ_LEN" "$PRETRAIN_STEPS" "$SAVE_EVERY"
|
| 159 |
+
printf '============================================================\n'
|
| 160 |
+
|
| 161 |
+
"$PYTHON_BIN" -m taoTrain.cli pretrain --config "$PRETRAIN_CONFIG"
|
| 162 |
+
|
| 163 |
+
PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/final_model.pt"
|
| 164 |
+
if [[ ! -f "$PRETRAIN_CKPT" ]]; then
|
| 165 |
+
echo "Expected pretrain checkpoint missing: $PRETRAIN_CKPT" >&2
|
| 166 |
+
exit 2
|
| 167 |
+
fi
|
| 168 |
+
|
| 169 |
+
"$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
|
| 170 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 171 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 172 |
+
--data-path "$DATA_PATH" \
|
| 173 |
+
--text-field text \
|
| 174 |
+
--output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_final.json" \
|
| 175 |
+
--batch-size 2 \
|
| 176 |
+
--seq-len "$SEQ_LEN" \
|
| 177 |
+
--device cuda \
|
| 178 |
+
--dtype bfloat16
|
| 179 |
+
|
| 180 |
+
"$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
|
| 181 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 182 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 183 |
+
--output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_final.json" \
|
| 184 |
+
--max-new-tokens 160 \
|
| 185 |
+
--temperature 0.8 \
|
| 186 |
+
--top-p 0.9 \
|
| 187 |
+
--prompt "The purpose of artificial intelligence is" \
|
| 188 |
+
--prompt "In a small village," \
|
| 189 |
+
--prompt "<user>Hello, who are you?<assistant>"
|
| 190 |
+
|
| 191 |
+
SFT_CONFIG="$OUTPUT_BASE/configs/sft.yaml"
|
| 192 |
+
cat > "$SFT_CONFIG" <<YAML
|
| 193 |
+
model:
|
| 194 |
+
architecture_type: taonet_ssm
|
| 195 |
+
vocab_size: 8192
|
| 196 |
+
hidden_dim: 1024
|
| 197 |
+
num_layers: 18
|
| 198 |
+
num_heads: 8
|
| 199 |
+
max_seq_length: $SEQ_LEN
|
| 200 |
+
d_latent_kv: 768
|
| 201 |
+
d_rope: 128
|
| 202 |
+
hidden_dim_ff: 3072
|
| 203 |
+
dropout: 0.0
|
| 204 |
+
gqa_groups: 1
|
| 205 |
+
use_factorized_embedding: false
|
| 206 |
+
d_embed_rank: 96
|
| 207 |
+
init_std: 0.02
|
| 208 |
+
ssm_core: dplr
|
| 209 |
+
ssm_hidden_dim: 32
|
| 210 |
+
ssm_mixer_dim: 256
|
| 211 |
+
ssm_num_lanes: 2
|
| 212 |
+
ssm_lane_combine: channel
|
| 213 |
+
ssm_lane_mode: split
|
| 214 |
+
ssm_split_mix: none
|
| 215 |
+
ssm_rank: 1
|
| 216 |
+
ssm_max_low_rank_scale: 0.1
|
| 217 |
+
ssm_finite_tail_correction: false
|
| 218 |
+
ssm_discretization: bilinear
|
| 219 |
+
ssm_kernel_mode: conv
|
| 220 |
+
ssm_kernel_threshold: 64
|
| 221 |
+
ssm_dt_min: 1e-3
|
| 222 |
+
ssm_dt_max: 1e-1
|
| 223 |
+
ssm_dt_init: 1e-2
|
| 224 |
+
ssm_use_d: true
|
| 225 |
+
ssm_activation: gelu
|
| 226 |
+
ssm_gate: true
|
| 227 |
+
ssm_input_gate: true
|
| 228 |
+
ssm_gate_type: channel
|
| 229 |
+
ssm_use_padding_mask: false
|
| 230 |
+
ssm_layer_scale_init: 0.1
|
| 231 |
+
ssm_branch_rms_norm: true
|
| 232 |
+
block_residual_rms_norm: false
|
| 233 |
+
$cap_yaml
|
| 234 |
+
ssm_local_shift: true
|
| 235 |
+
ssm_local_shift_init: 0.1
|
| 236 |
+
ssm_local_shift_per_channel: true
|
| 237 |
+
|
| 238 |
+
dataset:
|
| 239 |
+
split: train
|
| 240 |
+
instruction_column: input
|
| 241 |
+
response_column: output
|
| 242 |
+
local: true
|
| 243 |
+
jsonl_path: $SFT_DATA_PATH
|
| 244 |
+
samples_per_chunk: $SAMPLES_PER_CHUNK
|
| 245 |
+
tokenizer_type: sentencepiece
|
| 246 |
+
tokenizer_path: $TOKENIZER_PATH
|
| 247 |
+
tokenizer_threads: $TOKENIZER_THREADS
|
| 248 |
+
|
| 249 |
+
checkpoint_path: $PRETRAIN_CKPT
|
| 250 |
+
user_token: "<user>"
|
| 251 |
+
assistant_token: "<assistant>"
|
| 252 |
+
response_loss_only: true
|
| 253 |
+
|
| 254 |
+
batch_size: $BATCH_SIZE
|
| 255 |
+
num_epochs: 100000
|
| 256 |
+
max_steps: $SFT_STEPS
|
| 257 |
+
gradient_accumulation_steps: 1
|
| 258 |
+
max_grad_norm: 1.0
|
| 259 |
+
|
| 260 |
+
optimizer:
|
| 261 |
+
optimizer_type: adamw
|
| 262 |
+
learning_rate: $SFT_LR
|
| 263 |
+
weight_decay: 0.0
|
| 264 |
+
betas: [0.9, 0.999]
|
| 265 |
+
eps: 1e-8
|
| 266 |
+
|
| 267 |
+
scheduler:
|
| 268 |
+
scheduler_type: constant
|
| 269 |
+
warmup_steps: 100
|
| 270 |
+
warmup_ratio: 0.0
|
| 271 |
+
|
| 272 |
+
dtype: bfloat16
|
| 273 |
+
device: cuda
|
| 274 |
+
checkpoint_dir: $CHECKPOINT_BASE/sft
|
| 275 |
+
save_every_steps: $SFT_SAVE_EVERY
|
| 276 |
+
save_best_model: false
|
| 277 |
+
keep_last_n_checkpoints: 3
|
| 278 |
+
eval_every_steps: $SFT_SAVE_EVERY
|
| 279 |
+
eval_samples: 32
|
| 280 |
+
log_every_steps: 20
|
| 281 |
+
aim_repo: $OUTPUT_BASE/.aim-sft
|
| 282 |
+
seed: 44
|
| 283 |
+
num_workers: 0
|
| 284 |
+
pin_memory: true
|
| 285 |
+
YAML
|
| 286 |
+
|
| 287 |
+
printf '\n============================================================\n'
|
| 288 |
+
printf '200M branch-only pure SSM response-only SFT\n'
|
| 289 |
+
printf 'steps=%s batch=%s seq_len=%s save_every=%s\n' \
|
| 290 |
+
"$SFT_STEPS" "$BATCH_SIZE" "$SEQ_LEN" "$SFT_SAVE_EVERY"
|
| 291 |
+
printf '============================================================\n'
|
| 292 |
+
|
| 293 |
+
"$PYTHON_BIN" -m taoTrain.cli sft --config "$SFT_CONFIG"
|
| 294 |
+
|
| 295 |
+
SFT_CKPT="$CHECKPOINT_BASE/sft/final_model.pt"
|
| 296 |
+
if [[ ! -f "$SFT_CKPT" ]]; then
|
| 297 |
+
echo "Expected SFT checkpoint missing: $SFT_CKPT" >&2
|
| 298 |
+
exit 2
|
| 299 |
+
fi
|
| 300 |
+
|
| 301 |
+
"$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
|
| 302 |
+
--checkpoint "$SFT_CKPT" \
|
| 303 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 304 |
+
--output "$OUTPUT_BASE/diagnostics/generation_samples_sft_final.json" \
|
| 305 |
+
--max-new-tokens 220 \
|
| 306 |
+
--temperature 0.7 \
|
| 307 |
+
--top-p 0.9 \
|
| 308 |
+
--prompt "<user>Hello, who are you?<assistant>" \
|
| 309 |
+
--prompt "<user>Explain what artificial intelligence is in simple words.<assistant>" \
|
| 310 |
+
--prompt "<user>Give me three practical study tips.<assistant>"
|
| 311 |
+
|
| 312 |
+
cat > "$OUTPUT_BASE/model_card.json" <<JSON
|
| 313 |
+
{
|
| 314 |
+
"state": "completed",
|
| 315 |
+
"architecture": "taonet_ssm",
|
| 316 |
+
"candidate": "pure_ssm_196m_branch_rms_only",
|
| 317 |
+
"pretrain_checkpoint": "$PRETRAIN_CKPT",
|
| 318 |
+
"sft_checkpoint": "$SFT_CKPT",
|
| 319 |
+
"tokenizer_path": "$TOKENIZER_PATH",
|
| 320 |
+
"pretrain_config": "$PRETRAIN_CONFIG",
|
| 321 |
+
"sft_config": "$SFT_CONFIG",
|
| 322 |
+
"activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_final.json",
|
| 323 |
+
"pretrain_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_final.json",
|
| 324 |
+
"sft_samples": "$OUTPUT_BASE/diagnostics/generation_samples_sft_final.json"
|
| 325 |
+
}
|
| 326 |
+
JSON
|
| 327 |
+
|
| 328 |
+
echo "200M branch-only pure SSM pretrain+SFT run completed."
|
| 329 |
+
cat "$OUTPUT_BASE/model_card.json"
|
code/TaoTrain/scripts/remote/run_200m_until_selection.sh
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
ROOT_OUTPUT_DIR="${REPOBRIDGE_OUTPUT_DIR:-$(pwd)/results/200m-until-selection}"
|
| 5 |
+
ROOT_CHECKPOINT_DIR="${TAOTERN_CHECKPOINT_DIR:-$ROOT_OUTPUT_DIR/checkpoints}"
|
| 6 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 7 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 8 |
+
PILOT_TOKENS="${PILOT_TOKENS:-300000000}"
|
| 9 |
+
SERIOUS_TOKENS="${SERIOUS_TOKENS:-1000000000}"
|
| 10 |
+
PILOT_EVAL_BATCHES="${PILOT_EVAL_BATCHES:-64}"
|
| 11 |
+
SERIOUS_EVAL_BATCHES="${SERIOUS_EVAL_BATCHES:-128}"
|
| 12 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0003}"
|
| 13 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 14 |
+
|
| 15 |
+
ceil_div() {
|
| 16 |
+
local numerator="$1"
|
| 17 |
+
local denominator="$2"
|
| 18 |
+
echo $(( (numerator + denominator - 1) / denominator ))
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
run_phase() {
|
| 22 |
+
local phase="$1"
|
| 23 |
+
local target_tokens="$2"
|
| 24 |
+
local eval_batches="$3"
|
| 25 |
+
local tokens_per_step=$((BATCH_SIZE * SEQ_LEN))
|
| 26 |
+
local train_steps
|
| 27 |
+
train_steps="$(ceil_div "$target_tokens" "$tokens_per_step")"
|
| 28 |
+
|
| 29 |
+
printf '\n============================================================\n'
|
| 30 |
+
printf '200M until-selection phase: %s\n' "$phase"
|
| 31 |
+
printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s\n' \
|
| 32 |
+
"$target_tokens" "$BATCH_SIZE" "$SEQ_LEN" "$train_steps" "$eval_batches"
|
| 33 |
+
printf '============================================================\n'
|
| 34 |
+
|
| 35 |
+
REPOBRIDGE_OUTPUT_DIR="$ROOT_OUTPUT_DIR/$phase" \
|
| 36 |
+
TAOTERN_CHECKPOINT_DIR="$ROOT_CHECKPOINT_DIR/$phase" \
|
| 37 |
+
BATCH_SIZES="$BATCH_SIZE" \
|
| 38 |
+
SEQ_LEN="$SEQ_LEN" \
|
| 39 |
+
TRAIN_STEPS="$train_steps" \
|
| 40 |
+
EVAL_BATCHES="$eval_batches" \
|
| 41 |
+
LEARNING_RATE="$LEARNING_RATE" \
|
| 42 |
+
WEIGHT_DECAY="$WEIGHT_DECAY" \
|
| 43 |
+
bash scripts/remote/run_200m_base_suite.sh
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
mkdir -p "$ROOT_OUTPUT_DIR" "$ROOT_CHECKPOINT_DIR"
|
| 47 |
+
|
| 48 |
+
cat > "$ROOT_OUTPUT_DIR/run_plan.json" <<JSON
|
| 49 |
+
{
|
| 50 |
+
"stopping_point": "selection_after_1b_all_four_variants",
|
| 51 |
+
"batch_size": $BATCH_SIZE,
|
| 52 |
+
"seq_len": $SEQ_LEN,
|
| 53 |
+
"learning_rate": $LEARNING_RATE,
|
| 54 |
+
"weight_decay": $WEIGHT_DECAY,
|
| 55 |
+
"phases": [
|
| 56 |
+
{
|
| 57 |
+
"name": "pilot_300m",
|
| 58 |
+
"target_tokens_per_variant": $PILOT_TOKENS,
|
| 59 |
+
"train_steps": $(ceil_div "$PILOT_TOKENS" $((BATCH_SIZE * SEQ_LEN))),
|
| 60 |
+
"eval_batches": $PILOT_EVAL_BATCHES
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"name": "serious_1b",
|
| 64 |
+
"target_tokens_per_variant": $SERIOUS_TOKENS,
|
| 65 |
+
"train_steps": $(ceil_div "$SERIOUS_TOKENS" $((BATCH_SIZE * SEQ_LEN))),
|
| 66 |
+
"eval_batches": $SERIOUS_EVAL_BATCHES
|
| 67 |
+
}
|
| 68 |
+
],
|
| 69 |
+
"variants": [
|
| 70 |
+
"attention_196m",
|
| 71 |
+
"pure_ssm_196m_hadamard",
|
| 72 |
+
"pure_ssm_196m_nomix",
|
| 73 |
+
"hybrid_ssm_first_199m"
|
| 74 |
+
]
|
| 75 |
+
}
|
| 76 |
+
JSON
|
| 77 |
+
|
| 78 |
+
run_phase "pilot_300m" "$PILOT_TOKENS" "$PILOT_EVAL_BATCHES"
|
| 79 |
+
run_phase "serious_1b" "$SERIOUS_TOKENS" "$SERIOUS_EVAL_BATCHES"
|
| 80 |
+
|
| 81 |
+
echo "Selection gate reached after pilot_300m and serious_1b completed for all four variants."
|
code/TaoTrain/scripts/remote/run_pre_200m_branch_only_gate.sh
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 5 |
+
SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
|
| 6 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 7 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 8 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 9 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 10 |
+
OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/pre-200m-branch-only-gate}"
|
| 11 |
+
CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
|
| 12 |
+
|
| 13 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 14 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 15 |
+
TARGET_TOKENS="${TARGET_TOKENS:-100000000}"
|
| 16 |
+
MAX_TOKENS="${MAX_TOKENS:-100000000}"
|
| 17 |
+
MAX_RECORDS="${MAX_RECORDS:-200000}"
|
| 18 |
+
EVAL_BATCHES="${EVAL_BATCHES:-128}"
|
| 19 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0008}"
|
| 20 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 21 |
+
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-1000}"
|
| 22 |
+
SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
|
| 23 |
+
SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
|
| 24 |
+
SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
|
| 25 |
+
BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
|
| 26 |
+
|
| 27 |
+
ceil_div() {
|
| 28 |
+
local numerator="$1"
|
| 29 |
+
local denominator="$2"
|
| 30 |
+
echo $(( (numerator + denominator - 1) / denominator ))
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
TRAIN_STEPS="${TRAIN_STEPS:-$(ceil_div "$TARGET_TOKENS" $((BATCH_SIZE * SEQ_LEN)))}"
|
| 34 |
+
|
| 35 |
+
export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
|
| 36 |
+
mkdir -p "$OUTPUT_BASE" "$CHECKPOINT_BASE" "$OUTPUT_BASE/configs" "$OUTPUT_BASE/diagnostics"
|
| 37 |
+
|
| 38 |
+
cap_args=""
|
| 39 |
+
cap_yaml=""
|
| 40 |
+
if [[ -n "$BLOCK_RESIDUAL_RMS_CAP" ]]; then
|
| 41 |
+
cap_args="--block-residual-rms-cap $BLOCK_RESIDUAL_RMS_CAP"
|
| 42 |
+
cap_yaml=" block_residual_rms_cap: $BLOCK_RESIDUAL_RMS_CAP"
|
| 43 |
+
fi
|
| 44 |
+
|
| 45 |
+
cat > "$OUTPUT_BASE/run_plan.json" <<JSON
|
| 46 |
+
{
|
| 47 |
+
"purpose": "pre_200m_branch_only_quality_gate_before_4b_sft_chatbot_run",
|
| 48 |
+
"candidate": "pure_ssm_196m_branch_rms_only",
|
| 49 |
+
"target_tokens": $TARGET_TOKENS,
|
| 50 |
+
"train_steps": $TRAIN_STEPS,
|
| 51 |
+
"batch_size": $BATCH_SIZE,
|
| 52 |
+
"seq_len": $SEQ_LEN,
|
| 53 |
+
"learning_rate": $LEARNING_RATE,
|
| 54 |
+
"weight_decay": $WEIGHT_DECAY,
|
| 55 |
+
"block_residual_rms_cap": "${BLOCK_RESIDUAL_RMS_CAP:-null}",
|
| 56 |
+
"checks": [
|
| 57 |
+
"bounded pretrain loss/eval/grad telemetry",
|
| 58 |
+
"activation scale probe",
|
| 59 |
+
"sample generation",
|
| 60 |
+
"tiny SFT overfit probe"
|
| 61 |
+
]
|
| 62 |
+
}
|
| 63 |
+
JSON
|
| 64 |
+
|
| 65 |
+
printf '\n============================================================\n'
|
| 66 |
+
printf 'Pre-200M branch-only gate: pure SSM stabilized candidate\n'
|
| 67 |
+
printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s cap=%s\n' \
|
| 68 |
+
"$TARGET_TOKENS" "$BATCH_SIZE" "$SEQ_LEN" "$TRAIN_STEPS" "$EVAL_BATCHES" "${BLOCK_RESIDUAL_RMS_CAP:-none}"
|
| 69 |
+
printf '============================================================\n'
|
| 70 |
+
|
| 71 |
+
"$PYTHON_BIN" scripts/benchmark_taonet_real_tokens.py \
|
| 72 |
+
--data-path "$DATA_PATH" \
|
| 73 |
+
--text-field text \
|
| 74 |
+
--tokenizer-type sentencepiece \
|
| 75 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 76 |
+
--max-records "$MAX_RECORDS" \
|
| 77 |
+
--max-tokens "$MAX_TOKENS" \
|
| 78 |
+
--eval-fraction 0.1 \
|
| 79 |
+
--architectures taonet_ssm \
|
| 80 |
+
--batch-sizes "$BATCH_SIZE" \
|
| 81 |
+
--seq-len "$SEQ_LEN" \
|
| 82 |
+
--hidden-dim 1024 \
|
| 83 |
+
--num-layers 18 \
|
| 84 |
+
--num-heads 8 \
|
| 85 |
+
--d-latent-kv 768 \
|
| 86 |
+
--d-rope 128 \
|
| 87 |
+
--hidden-dim-ff 3072 \
|
| 88 |
+
--dropout 0.0 \
|
| 89 |
+
--ssm-core dplr \
|
| 90 |
+
--ssm-hidden-dims 32 \
|
| 91 |
+
--ssm-mixer-dims 256 \
|
| 92 |
+
--ssm-num-lanes-list 2 \
|
| 93 |
+
--ssm-lane-combine channel \
|
| 94 |
+
--ssm-lane-modes split \
|
| 95 |
+
--ssm-split-mixes none \
|
| 96 |
+
--ssm-rank 1 \
|
| 97 |
+
--ssm-kernel-mode conv \
|
| 98 |
+
--no-ssm-finite-tail-correction \
|
| 99 |
+
--ssm-gate-types channel \
|
| 100 |
+
--dtype bf16 \
|
| 101 |
+
--device cuda \
|
| 102 |
+
--warmup 1 \
|
| 103 |
+
--repeats 2 \
|
| 104 |
+
--backward \
|
| 105 |
+
--train-steps "$TRAIN_STEPS" \
|
| 106 |
+
--train-log-every "$TRAIN_LOG_EVERY" \
|
| 107 |
+
--learning-rate "$LEARNING_RATE" \
|
| 108 |
+
--weight-decay "$WEIGHT_DECAY" \
|
| 109 |
+
--max-grad-norm 1.0 \
|
| 110 |
+
--eval-batches "$EVAL_BATCHES" \
|
| 111 |
+
--ssm-local-shift \
|
| 112 |
+
--ssm-local-shift-per-channel \
|
| 113 |
+
--ssm-local-shift-init 0.1 \
|
| 114 |
+
--ssm-branch-rms-norm \
|
| 115 |
+
$cap_args \
|
| 116 |
+
--output-dir "$OUTPUT_BASE/pretrain" \
|
| 117 |
+
--resume-completed \
|
| 118 |
+
--incremental-output \
|
| 119 |
+
--save-case-checkpoints \
|
| 120 |
+
--checkpoint-dir "$CHECKPOINT_BASE/pretrain"
|
| 121 |
+
|
| 122 |
+
PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/latest.pt"
|
| 123 |
+
if [[ ! -f "$PRETRAIN_CKPT" ]]; then
|
| 124 |
+
echo "Expected pretrain checkpoint missing: $PRETRAIN_CKPT" >&2
|
| 125 |
+
exit 2
|
| 126 |
+
fi
|
| 127 |
+
|
| 128 |
+
"$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
|
| 129 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 130 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 131 |
+
--data-path "$DATA_PATH" \
|
| 132 |
+
--text-field text \
|
| 133 |
+
--output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json" \
|
| 134 |
+
--batch-size 2 \
|
| 135 |
+
--seq-len "$SEQ_LEN" \
|
| 136 |
+
--device cuda \
|
| 137 |
+
--dtype bfloat16
|
| 138 |
+
|
| 139 |
+
"$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
|
| 140 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 141 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 142 |
+
--output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json" \
|
| 143 |
+
--max-new-tokens 120 \
|
| 144 |
+
--temperature 0.8 \
|
| 145 |
+
--top-p 0.9 \
|
| 146 |
+
--prompt "The purpose of artificial intelligence is" \
|
| 147 |
+
--prompt "In a small village," \
|
| 148 |
+
--prompt "<user>Hello, who are you?<assistant>"
|
| 149 |
+
|
| 150 |
+
SFT_CONFIG="$OUTPUT_BASE/configs/sft_sanity.yaml"
|
| 151 |
+
cat > "$SFT_CONFIG" <<YAML
|
| 152 |
+
model:
|
| 153 |
+
architecture_type: taonet_ssm
|
| 154 |
+
vocab_size: 8192
|
| 155 |
+
hidden_dim: 1024
|
| 156 |
+
num_layers: 18
|
| 157 |
+
num_heads: 8
|
| 158 |
+
max_seq_length: $SEQ_LEN
|
| 159 |
+
d_latent_kv: 768
|
| 160 |
+
d_rope: 128
|
| 161 |
+
hidden_dim_ff: 3072
|
| 162 |
+
dropout: 0.0
|
| 163 |
+
gqa_groups: 1
|
| 164 |
+
use_factorized_embedding: false
|
| 165 |
+
d_embed_rank: 96
|
| 166 |
+
init_std: 0.02
|
| 167 |
+
ssm_core: dplr
|
| 168 |
+
ssm_hidden_dim: 32
|
| 169 |
+
ssm_mixer_dim: 256
|
| 170 |
+
ssm_num_lanes: 2
|
| 171 |
+
ssm_lane_combine: channel
|
| 172 |
+
ssm_lane_mode: split
|
| 173 |
+
ssm_split_mix: none
|
| 174 |
+
ssm_rank: 1
|
| 175 |
+
ssm_max_low_rank_scale: 0.1
|
| 176 |
+
ssm_finite_tail_correction: false
|
| 177 |
+
ssm_discretization: bilinear
|
| 178 |
+
ssm_kernel_mode: conv
|
| 179 |
+
ssm_kernel_threshold: 64
|
| 180 |
+
ssm_dt_min: 1e-3
|
| 181 |
+
ssm_dt_max: 1e-1
|
| 182 |
+
ssm_dt_init: 1e-2
|
| 183 |
+
ssm_use_d: true
|
| 184 |
+
ssm_activation: gelu
|
| 185 |
+
ssm_gate: true
|
| 186 |
+
ssm_input_gate: true
|
| 187 |
+
ssm_gate_type: channel
|
| 188 |
+
ssm_use_padding_mask: false
|
| 189 |
+
ssm_layer_scale_init: 0.1
|
| 190 |
+
ssm_branch_rms_norm: true
|
| 191 |
+
block_residual_rms_norm: false
|
| 192 |
+
$cap_yaml
|
| 193 |
+
ssm_local_shift: true
|
| 194 |
+
ssm_local_shift_init: 0.1
|
| 195 |
+
ssm_local_shift_per_channel: true
|
| 196 |
+
|
| 197 |
+
dataset:
|
| 198 |
+
split: train
|
| 199 |
+
instruction_column: input
|
| 200 |
+
response_column: output
|
| 201 |
+
local: true
|
| 202 |
+
jsonl_path: $SFT_DATA_PATH
|
| 203 |
+
samples_per_chunk: 2000
|
| 204 |
+
tokenizer_type: sentencepiece
|
| 205 |
+
tokenizer_path: $TOKENIZER_PATH
|
| 206 |
+
tokenizer_threads: 8
|
| 207 |
+
|
| 208 |
+
checkpoint_path: $PRETRAIN_CKPT
|
| 209 |
+
user_token: "<user>"
|
| 210 |
+
assistant_token: "<assistant>"
|
| 211 |
+
response_loss_only: true
|
| 212 |
+
|
| 213 |
+
batch_size: $BATCH_SIZE
|
| 214 |
+
num_epochs: 100000
|
| 215 |
+
max_steps: $SFT_SANITY_STEPS
|
| 216 |
+
gradient_accumulation_steps: 1
|
| 217 |
+
max_grad_norm: 1.0
|
| 218 |
+
|
| 219 |
+
optimizer:
|
| 220 |
+
optimizer_type: adamw
|
| 221 |
+
learning_rate: $SFT_SANITY_LR
|
| 222 |
+
weight_decay: 0.0
|
| 223 |
+
betas: [0.9, 0.999]
|
| 224 |
+
eps: 1e-8
|
| 225 |
+
|
| 226 |
+
scheduler:
|
| 227 |
+
scheduler_type: linearWarmup
|
| 228 |
+
warmup_steps: 0
|
| 229 |
+
|
| 230 |
+
dtype: bfloat16
|
| 231 |
+
device: cuda
|
| 232 |
+
checkpoint_dir: $CHECKPOINT_BASE/sft_sanity
|
| 233 |
+
save_every_steps: 5000
|
| 234 |
+
save_best_model: false
|
| 235 |
+
keep_last_n_checkpoints: 1
|
| 236 |
+
eval_every_steps: 5000
|
| 237 |
+
eval_samples: 32
|
| 238 |
+
log_every_steps: 10
|
| 239 |
+
aim_repo: $OUTPUT_BASE/.aim-sft-sanity
|
| 240 |
+
seed: 43
|
| 241 |
+
num_workers: 0
|
| 242 |
+
pin_memory: true
|
| 243 |
+
YAML
|
| 244 |
+
|
| 245 |
+
"$PYTHON_BIN" scripts/diagnostics/sft_sanity_check.py \
|
| 246 |
+
--config "$SFT_CONFIG" \
|
| 247 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 248 |
+
--output "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json" \
|
| 249 |
+
--samples "$SFT_SANITY_SAMPLES" \
|
| 250 |
+
--steps "$SFT_SANITY_STEPS" \
|
| 251 |
+
--lr "$SFT_SANITY_LR" \
|
| 252 |
+
--log-every 20 \
|
| 253 |
+
--device cuda \
|
| 254 |
+
--dtype bfloat16 \
|
| 255 |
+
--ssm-branch-rms-norm
|
| 256 |
+
|
| 257 |
+
cat > "$OUTPUT_BASE/gate_summary.json" <<JSON
|
| 258 |
+
{
|
| 259 |
+
"state": "completed",
|
| 260 |
+
"pretrain_checkpoint": "$PRETRAIN_CKPT",
|
| 261 |
+
"pretrain_results": "$OUTPUT_BASE/pretrain/taonet_real_token_benchmark.json",
|
| 262 |
+
"activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json",
|
| 263 |
+
"generation_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json",
|
| 264 |
+
"sft_sanity": "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json"
|
| 265 |
+
}
|
| 266 |
+
JSON
|
| 267 |
+
|
| 268 |
+
echo "Pre-200M branch-only gate completed."
|
| 269 |
+
cat "$OUTPUT_BASE/gate_summary.json"
|
code/TaoTrain/scripts/remote/run_pre_200m_stability_gate.sh
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 5 |
+
SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
|
| 6 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 7 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 8 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 9 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 10 |
+
OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/pre-200m-stability-gate}"
|
| 11 |
+
CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"
|
| 12 |
+
|
| 13 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 14 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 15 |
+
TARGET_TOKENS="${TARGET_TOKENS:-20000000}"
|
| 16 |
+
MAX_TOKENS="${MAX_TOKENS:-50000000}"
|
| 17 |
+
MAX_RECORDS="${MAX_RECORDS:-120000}"
|
| 18 |
+
EVAL_BATCHES="${EVAL_BATCHES:-64}"
|
| 19 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0008}"
|
| 20 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 21 |
+
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-250}"
|
| 22 |
+
SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
|
| 23 |
+
SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
|
| 24 |
+
SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
|
| 25 |
+
|
| 26 |
+
ceil_div() {
|
| 27 |
+
local numerator="$1"
|
| 28 |
+
local denominator="$2"
|
| 29 |
+
echo $(( (numerator + denominator - 1) / denominator ))
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
TRAIN_STEPS="${TRAIN_STEPS:-$(ceil_div "$TARGET_TOKENS" $((BATCH_SIZE * SEQ_LEN)))}"
|
| 33 |
+
|
| 34 |
+
export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"
|
| 35 |
+
mkdir -p "$OUTPUT_BASE" "$CHECKPOINT_BASE" "$OUTPUT_BASE/configs" "$OUTPUT_BASE/diagnostics"
|
| 36 |
+
|
| 37 |
+
cat > "$OUTPUT_BASE/run_plan.json" <<JSON
|
| 38 |
+
{
|
| 39 |
+
"purpose": "pre_200m_stability_gate_before_4b_sft_chatbot_run",
|
| 40 |
+
"candidate": "pure_ssm_196m_stabilized_m256_h32",
|
| 41 |
+
"target_tokens": $TARGET_TOKENS,
|
| 42 |
+
"train_steps": $TRAIN_STEPS,
|
| 43 |
+
"batch_size": $BATCH_SIZE,
|
| 44 |
+
"seq_len": $SEQ_LEN,
|
| 45 |
+
"learning_rate": $LEARNING_RATE,
|
| 46 |
+
"weight_decay": $WEIGHT_DECAY,
|
| 47 |
+
"checks": [
|
| 48 |
+
"bounded pretrain loss/eval/grad telemetry",
|
| 49 |
+
"activation scale probe",
|
| 50 |
+
"sample generation",
|
| 51 |
+
"tiny SFT overfit probe"
|
| 52 |
+
]
|
| 53 |
+
}
|
| 54 |
+
JSON
|
| 55 |
+
|
| 56 |
+
printf '\n============================================================\n'
|
| 57 |
+
printf 'Pre-200M stability gate: pure SSM stabilized candidate\n'
|
| 58 |
+
printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s\n' \
|
| 59 |
+
"$TARGET_TOKENS" "$BATCH_SIZE" "$SEQ_LEN" "$TRAIN_STEPS" "$EVAL_BATCHES"
|
| 60 |
+
printf '============================================================\n'
|
| 61 |
+
|
| 62 |
+
"$PYTHON_BIN" scripts/benchmark_taonet_real_tokens.py \
|
| 63 |
+
--data-path "$DATA_PATH" \
|
| 64 |
+
--text-field text \
|
| 65 |
+
--tokenizer-type sentencepiece \
|
| 66 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 67 |
+
--max-records "$MAX_RECORDS" \
|
| 68 |
+
--max-tokens "$MAX_TOKENS" \
|
| 69 |
+
--eval-fraction 0.1 \
|
| 70 |
+
--architectures taonet_ssm \
|
| 71 |
+
--batch-sizes "$BATCH_SIZE" \
|
| 72 |
+
--seq-len "$SEQ_LEN" \
|
| 73 |
+
--hidden-dim 1024 \
|
| 74 |
+
--num-layers 18 \
|
| 75 |
+
--num-heads 8 \
|
| 76 |
+
--d-latent-kv 768 \
|
| 77 |
+
--d-rope 128 \
|
| 78 |
+
--hidden-dim-ff 3072 \
|
| 79 |
+
--dropout 0.0 \
|
| 80 |
+
--ssm-core dplr \
|
| 81 |
+
--ssm-hidden-dims 32 \
|
| 82 |
+
--ssm-mixer-dims 256 \
|
| 83 |
+
--ssm-num-lanes-list 2 \
|
| 84 |
+
--ssm-lane-combine channel \
|
| 85 |
+
--ssm-lane-modes split \
|
| 86 |
+
--ssm-split-mixes none \
|
| 87 |
+
--ssm-rank 1 \
|
| 88 |
+
--ssm-kernel-mode conv \
|
| 89 |
+
--no-ssm-finite-tail-correction \
|
| 90 |
+
--ssm-gate-types channel \
|
| 91 |
+
--dtype bf16 \
|
| 92 |
+
--device cuda \
|
| 93 |
+
--warmup 1 \
|
| 94 |
+
--repeats 2 \
|
| 95 |
+
--backward \
|
| 96 |
+
--train-steps "$TRAIN_STEPS" \
|
| 97 |
+
--train-log-every "$TRAIN_LOG_EVERY" \
|
| 98 |
+
--learning-rate "$LEARNING_RATE" \
|
| 99 |
+
--weight-decay "$WEIGHT_DECAY" \
|
| 100 |
+
--max-grad-norm 1.0 \
|
| 101 |
+
--eval-batches "$EVAL_BATCHES" \
|
| 102 |
+
--ssm-local-shift \
|
| 103 |
+
--ssm-local-shift-per-channel \
|
| 104 |
+
--ssm-local-shift-init 0.1 \
|
| 105 |
+
--ssm-branch-rms-norm \
|
| 106 |
+
--ssm-branch-clip-value 1.0 \
|
| 107 |
+
--block-residual-rms-norm \
|
| 108 |
+
--block-residual-rms-target 1.0 \
|
| 109 |
+
--output-dir "$OUTPUT_BASE/pretrain" \
|
| 110 |
+
--resume-completed \
|
| 111 |
+
--incremental-output \
|
| 112 |
+
--save-case-checkpoints \
|
| 113 |
+
--checkpoint-dir "$CHECKPOINT_BASE/pretrain"
|
| 114 |
+
|
| 115 |
+
PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/latest.pt"
|
| 116 |
+
if [[ ! -f "$PRETRAIN_CKPT" ]]; then
|
| 117 |
+
echo "Expected pretrain checkpoint missing: $PRETRAIN_CKPT" >&2
|
| 118 |
+
exit 2
|
| 119 |
+
fi
|
| 120 |
+
|
| 121 |
+
"$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
|
| 122 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 123 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 124 |
+
--data-path "$DATA_PATH" \
|
| 125 |
+
--text-field text \
|
| 126 |
+
--output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json" \
|
| 127 |
+
--batch-size 2 \
|
| 128 |
+
--seq-len "$SEQ_LEN" \
|
| 129 |
+
--device cuda \
|
| 130 |
+
--dtype bfloat16
|
| 131 |
+
|
| 132 |
+
"$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
|
| 133 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 134 |
+
--tokenizer-path "$TOKENIZER_PATH" \
|
| 135 |
+
--output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json" \
|
| 136 |
+
--max-new-tokens 80 \
|
| 137 |
+
--temperature 0.8 \
|
| 138 |
+
--top-p 0.9 \
|
| 139 |
+
--prompt "The purpose of artificial intelligence is" \
|
| 140 |
+
--prompt "In a small village," \
|
| 141 |
+
--prompt "<user>Hello, who are you?<assistant>"
|
| 142 |
+
|
| 143 |
+
SFT_CONFIG="$OUTPUT_BASE/configs/sft_sanity.yaml"
|
| 144 |
+
cat > "$SFT_CONFIG" <<YAML
|
| 145 |
+
model:
|
| 146 |
+
architecture_type: taonet_ssm
|
| 147 |
+
vocab_size: 8192
|
| 148 |
+
hidden_dim: 1024
|
| 149 |
+
num_layers: 18
|
| 150 |
+
num_heads: 8
|
| 151 |
+
max_seq_length: $SEQ_LEN
|
| 152 |
+
d_latent_kv: 768
|
| 153 |
+
d_rope: 128
|
| 154 |
+
hidden_dim_ff: 3072
|
| 155 |
+
dropout: 0.0
|
| 156 |
+
gqa_groups: 1
|
| 157 |
+
use_factorized_embedding: false
|
| 158 |
+
d_embed_rank: 96
|
| 159 |
+
init_std: 0.02
|
| 160 |
+
ssm_core: dplr
|
| 161 |
+
ssm_hidden_dim: 32
|
| 162 |
+
ssm_mixer_dim: 256
|
| 163 |
+
ssm_num_lanes: 2
|
| 164 |
+
ssm_lane_combine: channel
|
| 165 |
+
ssm_lane_mode: split
|
| 166 |
+
ssm_split_mix: none
|
| 167 |
+
ssm_rank: 1
|
| 168 |
+
ssm_max_low_rank_scale: 0.1
|
| 169 |
+
ssm_finite_tail_correction: false
|
| 170 |
+
ssm_discretization: bilinear
|
| 171 |
+
ssm_kernel_mode: conv
|
| 172 |
+
ssm_kernel_threshold: 64
|
| 173 |
+
ssm_dt_min: 1e-3
|
| 174 |
+
ssm_dt_max: 1e-1
|
| 175 |
+
ssm_dt_init: 1e-2
|
| 176 |
+
ssm_use_d: true
|
| 177 |
+
ssm_activation: gelu
|
| 178 |
+
ssm_gate: true
|
| 179 |
+
ssm_input_gate: true
|
| 180 |
+
ssm_gate_type: channel
|
| 181 |
+
ssm_use_padding_mask: false
|
| 182 |
+
ssm_layer_scale_init: 0.1
|
| 183 |
+
ssm_branch_rms_norm: true
|
| 184 |
+
ssm_branch_clip_value: 1.0
|
| 185 |
+
block_residual_rms_norm: true
|
| 186 |
+
block_residual_rms_target: 1.0
|
| 187 |
+
ssm_local_shift: true
|
| 188 |
+
ssm_local_shift_init: 0.1
|
| 189 |
+
ssm_local_shift_per_channel: true
|
| 190 |
+
|
| 191 |
+
dataset:
|
| 192 |
+
split: train
|
| 193 |
+
instruction_column: input
|
| 194 |
+
response_column: output
|
| 195 |
+
local: true
|
| 196 |
+
jsonl_path: $SFT_DATA_PATH
|
| 197 |
+
samples_per_chunk: 2000
|
| 198 |
+
tokenizer_type: sentencepiece
|
| 199 |
+
tokenizer_path: $TOKENIZER_PATH
|
| 200 |
+
tokenizer_threads: 8
|
| 201 |
+
|
| 202 |
+
checkpoint_path: $PRETRAIN_CKPT
|
| 203 |
+
user_token: "<user>"
|
| 204 |
+
assistant_token: "<assistant>"
|
| 205 |
+
response_loss_only: true
|
| 206 |
+
|
| 207 |
+
batch_size: $BATCH_SIZE
|
| 208 |
+
num_epochs: 100000
|
| 209 |
+
max_steps: $SFT_SANITY_STEPS
|
| 210 |
+
gradient_accumulation_steps: 1
|
| 211 |
+
max_grad_norm: 1.0
|
| 212 |
+
|
| 213 |
+
optimizer:
|
| 214 |
+
optimizer_type: adamw
|
| 215 |
+
learning_rate: $SFT_SANITY_LR
|
| 216 |
+
weight_decay: 0.0
|
| 217 |
+
betas: [0.9, 0.999]
|
| 218 |
+
eps: 1e-8
|
| 219 |
+
|
| 220 |
+
scheduler:
|
| 221 |
+
scheduler_type: linearWarmup
|
| 222 |
+
warmup_steps: 0
|
| 223 |
+
|
| 224 |
+
dtype: bfloat16
|
| 225 |
+
device: cuda
|
| 226 |
+
checkpoint_dir: $CHECKPOINT_BASE/sft_sanity
|
| 227 |
+
save_every_steps: 5000
|
| 228 |
+
save_best_model: false
|
| 229 |
+
keep_last_n_checkpoints: 1
|
| 230 |
+
eval_every_steps: 5000
|
| 231 |
+
eval_samples: 32
|
| 232 |
+
log_every_steps: 10
|
| 233 |
+
aim_repo: $OUTPUT_BASE/.aim-sft-sanity
|
| 234 |
+
seed: 43
|
| 235 |
+
num_workers: 0
|
| 236 |
+
pin_memory: true
|
| 237 |
+
YAML
|
| 238 |
+
|
| 239 |
+
"$PYTHON_BIN" scripts/diagnostics/sft_sanity_check.py \
|
| 240 |
+
--config "$SFT_CONFIG" \
|
| 241 |
+
--checkpoint "$PRETRAIN_CKPT" \
|
| 242 |
+
--output "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json" \
|
| 243 |
+
--samples "$SFT_SANITY_SAMPLES" \
|
| 244 |
+
--steps "$SFT_SANITY_STEPS" \
|
| 245 |
+
--lr "$SFT_SANITY_LR" \
|
| 246 |
+
--log-every 20 \
|
| 247 |
+
--device cuda \
|
| 248 |
+
--dtype bfloat16 \
|
| 249 |
+
--ssm-branch-rms-norm \
|
| 250 |
+
--ssm-branch-clip-value 1.0 \
|
| 251 |
+
--block-residual-rms-norm \
|
| 252 |
+
--block-residual-rms-target 1.0
|
| 253 |
+
|
| 254 |
+
cat > "$OUTPUT_BASE/gate_summary.json" <<JSON
|
| 255 |
+
{
|
| 256 |
+
"state": "completed",
|
| 257 |
+
"pretrain_checkpoint": "$PRETRAIN_CKPT",
|
| 258 |
+
"pretrain_results": "$OUTPUT_BASE/pretrain/taonet_real_token_benchmark.json",
|
| 259 |
+
"activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json",
|
| 260 |
+
"generation_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json",
|
| 261 |
+
"sft_sanity": "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json"
|
| 262 |
+
}
|
| 263 |
+
JSON
|
| 264 |
+
|
| 265 |
+
echo "Pre-200M stability gate completed."
|
| 266 |
+
cat "$OUTPUT_BASE/gate_summary.json"
|
| 267 |
+
|
code/TaoTrain/scripts/remote/submit_200m_base_suite.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-taotern-200m-base-suite-$(date +%Y%m%d-%H%M%S)}"
|
| 5 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 6 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 8 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 9 |
+
|
| 10 |
+
JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH bash scripts/remote/run_200m_base_suite.sh"
|
| 11 |
+
|
| 12 |
+
export RUN_ID JOB_ROOT JOB_COMMAND
|
| 13 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
|
| 14 |
+
export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
|
| 15 |
+
|
| 16 |
+
bash scripts/remote/submit_detached_job.sh
|
code/TaoTrain/scripts/remote/submit_200m_branch_only_chat.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-taotern-200m-branch-only-chat-$(date +%Y%m%d-%H%M%S)}"
|
| 5 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 6 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 8 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 9 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 10 |
+
SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
|
| 11 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 12 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 13 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 14 |
+
PRETRAIN_TOKENS="${PRETRAIN_TOKENS:-4000000000}"
|
| 15 |
+
SFT_STEPS="${SFT_STEPS:-50000}"
|
| 16 |
+
PRETRAIN_LR="${PRETRAIN_LR:-0.0008}"
|
| 17 |
+
SFT_LR="${SFT_LR:-0.00005}"
|
| 18 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 19 |
+
LOG_EVERY="${LOG_EVERY:-100}"
|
| 20 |
+
SAVE_EVERY="${SAVE_EVERY:-100000}"
|
| 21 |
+
SFT_SAVE_EVERY="${SFT_SAVE_EVERY:-10000}"
|
| 22 |
+
TOKENIZER_THREADS="${TOKENIZER_THREADS:-8}"
|
| 23 |
+
SAMPLES_PER_CHUNK="${SAMPLES_PER_CHUNK:-2000}"
|
| 24 |
+
BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
|
| 25 |
+
|
| 26 |
+
JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH SFT_DATA_PATH=$SFT_DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE PRETRAIN_TOKENS=$PRETRAIN_TOKENS SFT_STEPS=$SFT_STEPS PRETRAIN_LR=$PRETRAIN_LR SFT_LR=$SFT_LR WEIGHT_DECAY=$WEIGHT_DECAY LOG_EVERY=$LOG_EVERY SAVE_EVERY=$SAVE_EVERY SFT_SAVE_EVERY=$SFT_SAVE_EVERY TOKENIZER_THREADS=$TOKENIZER_THREADS SAMPLES_PER_CHUNK=$SAMPLES_PER_CHUNK BLOCK_RESIDUAL_RMS_CAP=$BLOCK_RESIDUAL_RMS_CAP bash scripts/remote/run_200m_branch_only_chat.sh"
|
| 27 |
+
|
| 28 |
+
export RUN_ID JOB_ROOT JOB_COMMAND
|
| 29 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
|
| 30 |
+
export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
|
| 31 |
+
|
| 32 |
+
bash scripts/remote/submit_detached_job.sh
|
code/TaoTrain/scripts/remote/submit_200m_until_selection.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-taotern-200m-until-selection-$(date +%Y%m%d-%H%M%S)}"
|
| 5 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 6 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 8 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 9 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 10 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 11 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 12 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 13 |
+
PILOT_TOKENS="${PILOT_TOKENS:-300000000}"
|
| 14 |
+
SERIOUS_TOKENS="${SERIOUS_TOKENS:-1000000000}"
|
| 15 |
+
PILOT_EVAL_BATCHES="${PILOT_EVAL_BATCHES:-64}"
|
| 16 |
+
SERIOUS_EVAL_BATCHES="${SERIOUS_EVAL_BATCHES:-128}"
|
| 17 |
+
MAX_TOKENS="${MAX_TOKENS:-50000000}"
|
| 18 |
+
MAX_RECORDS="${MAX_RECORDS:-100000}"
|
| 19 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0003}"
|
| 20 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 21 |
+
|
| 22 |
+
JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE PILOT_TOKENS=$PILOT_TOKENS SERIOUS_TOKENS=$SERIOUS_TOKENS PILOT_EVAL_BATCHES=$PILOT_EVAL_BATCHES SERIOUS_EVAL_BATCHES=$SERIOUS_EVAL_BATCHES MAX_TOKENS=$MAX_TOKENS MAX_RECORDS=$MAX_RECORDS LEARNING_RATE=$LEARNING_RATE WEIGHT_DECAY=$WEIGHT_DECAY bash scripts/remote/run_200m_until_selection.sh"
|
| 23 |
+
|
| 24 |
+
export RUN_ID JOB_ROOT JOB_COMMAND
|
| 25 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
|
| 26 |
+
export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
|
| 27 |
+
|
| 28 |
+
bash scripts/remote/submit_detached_job.sh
|
code/TaoTrain/scripts/remote/submit_detached_job.sh
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-}"
|
| 5 |
+
JOB_COMMAND="${JOB_COMMAND:-}"
|
| 6 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 7 |
+
SESSION_NAME="${SESSION_NAME:-}"
|
| 8 |
+
|
| 9 |
+
if [[ -z "$RUN_ID" ]]; then
|
| 10 |
+
echo "RUN_ID is required" >&2
|
| 11 |
+
exit 2
|
| 12 |
+
fi
|
| 13 |
+
|
| 14 |
+
if [[ -z "$JOB_COMMAND" ]]; then
|
| 15 |
+
echo "JOB_COMMAND is required" >&2
|
| 16 |
+
exit 2
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
if ! command -v tmux >/dev/null 2>&1; then
|
| 20 |
+
echo "tmux is required for detached jobs" >&2
|
| 21 |
+
exit 2
|
| 22 |
+
fi
|
| 23 |
+
|
| 24 |
+
safe_run_id="$(printf '%s' "$RUN_ID" | tr -c 'A-Za-z0-9_.-' '_')"
|
| 25 |
+
session="${SESSION_NAME:-taotern_${safe_run_id}}"
|
| 26 |
+
job_dir="${JOB_ROOT%/}/${safe_run_id}"
|
| 27 |
+
output_dir="${OUTPUT_DIR:-$job_dir/outputs}"
|
| 28 |
+
checkpoint_dir="${CHECKPOINT_DIR:-$job_dir/checkpoints}"
|
| 29 |
+
|
| 30 |
+
mkdir -p "$job_dir" "$output_dir" "$checkpoint_dir"
|
| 31 |
+
|
| 32 |
+
cat > "$job_dir/command.sh" <<EOF
|
| 33 |
+
#!/usr/bin/env bash
|
| 34 |
+
set -euo pipefail
|
| 35 |
+
export REPOBRIDGE_OUTPUT_DIR="$output_dir"
|
| 36 |
+
export TAOTERN_CHECKPOINT_DIR="$checkpoint_dir"
|
| 37 |
+
cd "$(pwd)"
|
| 38 |
+
$JOB_COMMAND
|
| 39 |
+
EOF
|
| 40 |
+
chmod +x "$job_dir/command.sh"
|
| 41 |
+
|
| 42 |
+
cat > "$job_dir/run.sh" <<'EOF'
|
| 43 |
+
#!/usr/bin/env bash
|
| 44 |
+
set -Eeuo pipefail
|
| 45 |
+
|
| 46 |
+
job_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 47 |
+
status_json="$job_dir/status.json"
|
| 48 |
+
log_path="$job_dir/train.log"
|
| 49 |
+
telemetry_path="$job_dir/gpu_telemetry_nvidia_smi.csv"
|
| 50 |
+
|
| 51 |
+
write_status() {
|
| 52 |
+
local state="$1"
|
| 53 |
+
local extra="${2:-}"
|
| 54 |
+
local now
|
| 55 |
+
now="$(date -Iseconds)"
|
| 56 |
+
cat > "$status_json" <<STATUS
|
| 57 |
+
{"state":"$state","updated_at":"$now","job_dir":"$job_dir"$extra}
|
| 58 |
+
STATUS
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
write_status "running" ",\"start_time\":\"$(date -Iseconds)\""
|
| 62 |
+
rm -f "$job_dir/DONE" "$job_dir/FAILED"
|
| 63 |
+
|
| 64 |
+
monitor_pid=""
|
| 65 |
+
if command -v nvidia-smi >/dev/null 2>&1; then
|
| 66 |
+
(
|
| 67 |
+
while true; do
|
| 68 |
+
date -Iseconds
|
| 69 |
+
nvidia-smi --query-gpu=timestamp,name,utilization.gpu,utilization.memory,memory.used,memory.total,power.draw,temperature.gpu --format=csv,noheader,nounits
|
| 70 |
+
sleep 5
|
| 71 |
+
done
|
| 72 |
+
) > "$telemetry_path" 2>&1 &
|
| 73 |
+
monitor_pid="$!"
|
| 74 |
+
fi
|
| 75 |
+
|
| 76 |
+
cleanup() {
|
| 77 |
+
if [[ -n "$monitor_pid" ]]; then
|
| 78 |
+
kill "$monitor_pid" 2>/dev/null || true
|
| 79 |
+
wait "$monitor_pid" 2>/dev/null || true
|
| 80 |
+
fi
|
| 81 |
+
}
|
| 82 |
+
trap cleanup EXIT
|
| 83 |
+
|
| 84 |
+
set +e
|
| 85 |
+
"$job_dir/command.sh" > "$log_path" 2>&1
|
| 86 |
+
exit_code="$?"
|
| 87 |
+
set -e
|
| 88 |
+
|
| 89 |
+
if [[ "$exit_code" -eq 0 ]]; then
|
| 90 |
+
touch "$job_dir/DONE"
|
| 91 |
+
write_status "completed" ",\"exit_code\":0,\"end_time\":\"$(date -Iseconds)\""
|
| 92 |
+
else
|
| 93 |
+
echo "$exit_code" > "$job_dir/FAILED"
|
| 94 |
+
write_status "failed" ",\"exit_code\":$exit_code,\"end_time\":\"$(date -Iseconds)\""
|
| 95 |
+
fi
|
| 96 |
+
|
| 97 |
+
exit "$exit_code"
|
| 98 |
+
EOF
|
| 99 |
+
chmod +x "$job_dir/run.sh"
|
| 100 |
+
|
| 101 |
+
cat > "$job_dir/status.json" <<STATUS
|
| 102 |
+
{"state":"submitted","updated_at":"$(date -Iseconds)","job_dir":"$job_dir","session":"$session","output_dir":"$output_dir","checkpoint_dir":"$checkpoint_dir"}
|
| 103 |
+
STATUS
|
| 104 |
+
|
| 105 |
+
if tmux has-session -t "$session" 2>/dev/null; then
|
| 106 |
+
echo "tmux session already exists: $session" >&2
|
| 107 |
+
echo "$job_dir"
|
| 108 |
+
exit 3
|
| 109 |
+
fi
|
| 110 |
+
|
| 111 |
+
tmux new-session -d -s "$session" "bash '$job_dir/run.sh'"
|
| 112 |
+
|
| 113 |
+
echo "Submitted detached job"
|
| 114 |
+
echo " run_id: $safe_run_id"
|
| 115 |
+
echo " session: $session"
|
| 116 |
+
echo " job_dir: $job_dir"
|
| 117 |
+
echo " output_dir: $output_dir"
|
| 118 |
+
echo " checkpoint_dir: $checkpoint_dir"
|
code/TaoTrain/scripts/remote/submit_pre_200m_branch_only_gate.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-taotern-pre-200m-branch-only-gate-$(date +%Y%m%d-%H%M%S)}"
|
| 5 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 6 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 8 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 9 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 10 |
+
SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
|
| 11 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 12 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 13 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 14 |
+
TARGET_TOKENS="${TARGET_TOKENS:-100000000}"
|
| 15 |
+
MAX_TOKENS="${MAX_TOKENS:-100000000}"
|
| 16 |
+
MAX_RECORDS="${MAX_RECORDS:-200000}"
|
| 17 |
+
EVAL_BATCHES="${EVAL_BATCHES:-128}"
|
| 18 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0008}"
|
| 19 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 20 |
+
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-1000}"
|
| 21 |
+
SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
|
| 22 |
+
SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
|
| 23 |
+
SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
|
| 24 |
+
BLOCK_RESIDUAL_RMS_CAP="${BLOCK_RESIDUAL_RMS_CAP:-}"
|
| 25 |
+
|
| 26 |
+
JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH SFT_DATA_PATH=$SFT_DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE TARGET_TOKENS=$TARGET_TOKENS MAX_TOKENS=$MAX_TOKENS MAX_RECORDS=$MAX_RECORDS EVAL_BATCHES=$EVAL_BATCHES LEARNING_RATE=$LEARNING_RATE WEIGHT_DECAY=$WEIGHT_DECAY TRAIN_LOG_EVERY=$TRAIN_LOG_EVERY SFT_SANITY_SAMPLES=$SFT_SANITY_SAMPLES SFT_SANITY_STEPS=$SFT_SANITY_STEPS SFT_SANITY_LR=$SFT_SANITY_LR BLOCK_RESIDUAL_RMS_CAP=$BLOCK_RESIDUAL_RMS_CAP bash scripts/remote/run_pre_200m_branch_only_gate.sh"
|
| 27 |
+
|
| 28 |
+
export RUN_ID JOB_ROOT JOB_COMMAND
|
| 29 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
|
| 30 |
+
export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
|
| 31 |
+
|
| 32 |
+
bash scripts/remote/submit_detached_job.sh
|
| 33 |
+
|
code/TaoTrain/scripts/remote/submit_pre_200m_stability_gate.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-taotern-pre-200m-stability-gate-$(date +%Y%m%d-%H%M%S)}"
|
| 5 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 6 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 8 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 9 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 10 |
+
SFT_DATA_PATH="${SFT_DATA_PATH:-/home/student/Data/TaoData/sft.jsonl}"
|
| 11 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 12 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 13 |
+
BATCH_SIZE="${BATCH_SIZE:-8}"
|
| 14 |
+
TARGET_TOKENS="${TARGET_TOKENS:-20000000}"
|
| 15 |
+
MAX_TOKENS="${MAX_TOKENS:-50000000}"
|
| 16 |
+
MAX_RECORDS="${MAX_RECORDS:-120000}"
|
| 17 |
+
EVAL_BATCHES="${EVAL_BATCHES:-64}"
|
| 18 |
+
LEARNING_RATE="${LEARNING_RATE:-0.0008}"
|
| 19 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
|
| 20 |
+
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-250}"
|
| 21 |
+
SFT_SANITY_SAMPLES="${SFT_SANITY_SAMPLES:-4}"
|
| 22 |
+
SFT_SANITY_STEPS="${SFT_SANITY_STEPS:-120}"
|
| 23 |
+
SFT_SANITY_LR="${SFT_SANITY_LR:-0.00005}"
|
| 24 |
+
|
| 25 |
+
JOB_COMMAND="REMOTE_REPO=$REMOTE_REPO PYTHON_BIN=$PYTHON_BIN SSM_REPO_PATH=$SSM_REPO_PATH DATA_PATH=$DATA_PATH SFT_DATA_PATH=$SFT_DATA_PATH TOKENIZER_PATH=$TOKENIZER_PATH SEQ_LEN=$SEQ_LEN BATCH_SIZE=$BATCH_SIZE TARGET_TOKENS=$TARGET_TOKENS MAX_TOKENS=$MAX_TOKENS MAX_RECORDS=$MAX_RECORDS EVAL_BATCHES=$EVAL_BATCHES LEARNING_RATE=$LEARNING_RATE WEIGHT_DECAY=$WEIGHT_DECAY TRAIN_LOG_EVERY=$TRAIN_LOG_EVERY SFT_SANITY_SAMPLES=$SFT_SANITY_SAMPLES SFT_SANITY_STEPS=$SFT_SANITY_STEPS SFT_SANITY_LR=$SFT_SANITY_LR bash scripts/remote/run_pre_200m_stability_gate.sh"
|
| 26 |
+
|
| 27 |
+
export RUN_ID JOB_ROOT JOB_COMMAND
|
| 28 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
|
| 29 |
+
export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
|
| 30 |
+
|
| 31 |
+
bash scripts/remote/submit_detached_job.sh
|
| 32 |
+
|
code/TaoTrain/scripts/remote/submit_ssm_improvement_sweep.sh
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
RUN_ID="${RUN_ID:-ssm-improvement-sweep-$(date +%Y%m%d-%H%M%S)}"
|
| 5 |
+
DATA_PATH="${DATA_PATH:-/home/student/Data/TaoData/pretrain.jsonl}"
|
| 6 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
|
| 7 |
+
SSM_REPO_PATH="${SSM_REPO_PATH:-/home/student/YouZheng/gamma_ssm_repo}"
|
| 8 |
+
PYTHON_BIN="${PYTHON_BIN:-/home/student/.venv/bin/python}"
|
| 9 |
+
JOB_ROOT="${JOB_ROOT:-/home/student/YouZheng/jobs/taotern}"
|
| 10 |
+
REMOTE_REPO="${REMOTE_REPO:-$(pwd)}"
|
| 11 |
+
|
| 12 |
+
# This is intentionally below the eventual 200M scale. It is a model-selection
|
| 13 |
+
# sweep that can run unattended and resume completed benchmark cases.
|
| 14 |
+
MAX_TOKENS="${MAX_TOKENS:-100000000}"
|
| 15 |
+
MAX_RECORDS="${MAX_RECORDS:-150000}"
|
| 16 |
+
TRAIN_STEPS="${TRAIN_STEPS:-5000}"
|
| 17 |
+
EVAL_BATCHES="${EVAL_BATCHES:-96}"
|
| 18 |
+
BATCH_SIZES="${BATCH_SIZES:-32,64}"
|
| 19 |
+
SEQ_LEN="${SEQ_LEN:-512}"
|
| 20 |
+
|
| 21 |
+
JOB_COMMAND="PYTHONPATH=$REMOTE_REPO/src:$SSM_REPO_PATH $PYTHON_BIN scripts/benchmark_taonet_real_tokens.py \
|
| 22 |
+
--data-path $DATA_PATH \
|
| 23 |
+
--text-field text \
|
| 24 |
+
--tokenizer-type sentencepiece \
|
| 25 |
+
--tokenizer-path $TOKENIZER_PATH \
|
| 26 |
+
--max-records $MAX_RECORDS \
|
| 27 |
+
--max-tokens $MAX_TOKENS \
|
| 28 |
+
--eval-fraction 0.1 \
|
| 29 |
+
--architectures taonet,taonet_ssm,taonet_hybrid \
|
| 30 |
+
--batch-sizes $BATCH_SIZES \
|
| 31 |
+
--seq-len $SEQ_LEN \
|
| 32 |
+
--hidden-dim 256 \
|
| 33 |
+
--num-layers 4 \
|
| 34 |
+
--num-heads 4 \
|
| 35 |
+
--d-latent-kv 192 \
|
| 36 |
+
--hidden-dim-ff 1024 \
|
| 37 |
+
--ssm-core dplr \
|
| 38 |
+
--ssm-hidden-dims 16,32 \
|
| 39 |
+
--ssm-mixer-dims 128,256 \
|
| 40 |
+
--ssm-num-lanes-list 1,2 \
|
| 41 |
+
--ssm-lane-combine channel \
|
| 42 |
+
--ssm-lane-modes full,split \
|
| 43 |
+
--ssm-split-mixes none,hadamard \
|
| 44 |
+
--ssm-rank 1 \
|
| 45 |
+
--ssm-kernel-mode conv \
|
| 46 |
+
--no-ssm-finite-tail-correction \
|
| 47 |
+
--ssm-gate-types channel \
|
| 48 |
+
--hybrid-patterns attention_first,ssm_first,single_ssm_middle,single_ssm_late \
|
| 49 |
+
--dtype bf16 \
|
| 50 |
+
--device cuda \
|
| 51 |
+
--warmup 2 \
|
| 52 |
+
--repeats 3 \
|
| 53 |
+
--backward \
|
| 54 |
+
--train-steps $TRAIN_STEPS \
|
| 55 |
+
--learning-rate 0.0008 \
|
| 56 |
+
--weight-decay 0.01 \
|
| 57 |
+
--eval-batches $EVAL_BATCHES \
|
| 58 |
+
--ssm-local-shift \
|
| 59 |
+
--ssm-local-shift-per-channel \
|
| 60 |
+
--ssm-local-shift-init 0.1 \
|
| 61 |
+
--output-dir \"\$REPOBRIDGE_OUTPUT_DIR\" \
|
| 62 |
+
--resume-completed \
|
| 63 |
+
--incremental-output \
|
| 64 |
+
--save-case-checkpoints \
|
| 65 |
+
--checkpoint-dir \"\$TAOTERN_CHECKPOINT_DIR\""
|
| 66 |
+
|
| 67 |
+
export RUN_ID JOB_ROOT JOB_COMMAND
|
| 68 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-$JOB_ROOT/$RUN_ID/outputs}"
|
| 69 |
+
export CHECKPOINT_DIR="${CHECKPOINT_DIR:-$JOB_ROOT/$RUN_ID/checkpoints}"
|
| 70 |
+
|
| 71 |
+
bash scripts/remote/submit_detached_job.sh
|
code/TaoTrain/scripts/summarize_taonet_benchmark_suite.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _as_float(value: str | None) -> float | None:
|
| 11 |
+
if value is None or value == "":
|
| 12 |
+
return None
|
| 13 |
+
try:
|
| 14 |
+
return float(value)
|
| 15 |
+
except ValueError:
|
| 16 |
+
return None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _load_rows(root: Path) -> list[dict[str, Any]]:
|
| 20 |
+
rows: list[dict[str, Any]] = []
|
| 21 |
+
for csv_path in sorted(root.glob("*/taonet_real_token_benchmark.csv")):
|
| 22 |
+
variant = csv_path.parent.name
|
| 23 |
+
with csv_path.open("r", newline="", encoding="utf-8") as handle:
|
| 24 |
+
for row in csv.DictReader(handle):
|
| 25 |
+
row = dict(row)
|
| 26 |
+
row["variant"] = variant
|
| 27 |
+
rows.append(row)
|
| 28 |
+
return rows
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _best_forward_backward(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 32 |
+
candidates = [row for row in rows if row.get("mode") == "forward_backward"]
|
| 33 |
+
grouped: dict[str, list[dict[str, Any]]] = {}
|
| 34 |
+
for row in candidates:
|
| 35 |
+
grouped.setdefault(row["variant"], []).append(row)
|
| 36 |
+
|
| 37 |
+
best_rows = []
|
| 38 |
+
for variant, items in grouped.items():
|
| 39 |
+
items.sort(
|
| 40 |
+
key=lambda row: (
|
| 41 |
+
_as_float(row.get("eval_loss")) if _as_float(row.get("eval_loss")) is not None else float("inf"),
|
| 42 |
+
-(_as_float(row.get("eval_accuracy")) or 0.0),
|
| 43 |
+
)
|
| 44 |
+
)
|
| 45 |
+
best_rows.append(items[0])
|
| 46 |
+
best_rows.sort(
|
| 47 |
+
key=lambda row: (
|
| 48 |
+
_as_float(row.get("eval_loss")) if _as_float(row.get("eval_loss")) is not None else float("inf"),
|
| 49 |
+
-(_as_float(row.get("eval_accuracy")) or 0.0),
|
| 50 |
+
)
|
| 51 |
+
)
|
| 52 |
+
return best_rows
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _project(row: dict[str, Any]) -> dict[str, Any]:
|
| 56 |
+
keys = [
|
| 57 |
+
"variant",
|
| 58 |
+
"architecture",
|
| 59 |
+
"hybrid_pattern",
|
| 60 |
+
"batch_size",
|
| 61 |
+
"seq_len",
|
| 62 |
+
"total_params",
|
| 63 |
+
"ssm_core",
|
| 64 |
+
"ssm_hidden_dim",
|
| 65 |
+
"ssm_mixer_dim",
|
| 66 |
+
"ssm_num_lanes",
|
| 67 |
+
"ssm_lane_mode",
|
| 68 |
+
"ssm_split_mix",
|
| 69 |
+
"tokens_per_s_mean",
|
| 70 |
+
"eval_loss",
|
| 71 |
+
"eval_perplexity",
|
| 72 |
+
"eval_accuracy",
|
| 73 |
+
"train_final_loss",
|
| 74 |
+
"train_seconds",
|
| 75 |
+
"peak_reserved_mb",
|
| 76 |
+
"case_id",
|
| 77 |
+
"checkpoint_path",
|
| 78 |
+
]
|
| 79 |
+
return {key: row.get(key, "") for key in keys}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _write_markdown(summary: list[dict[str, Any]], path: Path) -> None:
|
| 83 |
+
headers = [
|
| 84 |
+
"variant",
|
| 85 |
+
"architecture",
|
| 86 |
+
"batch",
|
| 87 |
+
"params",
|
| 88 |
+
"eval_loss",
|
| 89 |
+
"eval_acc",
|
| 90 |
+
"tok/s",
|
| 91 |
+
"checkpoint",
|
| 92 |
+
]
|
| 93 |
+
lines = [
|
| 94 |
+
"# TaoNet Benchmark Suite Summary",
|
| 95 |
+
"",
|
| 96 |
+
"| " + " | ".join(headers) + " |",
|
| 97 |
+
"| " + " | ".join(["---"] * len(headers)) + " |",
|
| 98 |
+
]
|
| 99 |
+
for row in summary:
|
| 100 |
+
lines.append(
|
| 101 |
+
"| "
|
| 102 |
+
+ " | ".join(
|
| 103 |
+
[
|
| 104 |
+
str(row["variant"]),
|
| 105 |
+
str(row["architecture"]),
|
| 106 |
+
str(row["batch_size"]),
|
| 107 |
+
str(row["total_params"]),
|
| 108 |
+
str(row["eval_loss"]),
|
| 109 |
+
str(row["eval_accuracy"]),
|
| 110 |
+
str(row["tokens_per_s_mean"]),
|
| 111 |
+
str(row["checkpoint_path"]),
|
| 112 |
+
]
|
| 113 |
+
)
|
| 114 |
+
+ " |"
|
| 115 |
+
)
|
| 116 |
+
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def main() -> None:
|
| 120 |
+
parser = argparse.ArgumentParser(description="Summarize a TaoNet benchmark suite output directory.")
|
| 121 |
+
parser.add_argument("--suite-dir", required=True, help="Directory containing one subdirectory per benchmark variant.")
|
| 122 |
+
parser.add_argument("--output-json", default="", help="Summary JSON path. Defaults to <suite-dir>/suite_summary.json.")
|
| 123 |
+
parser.add_argument("--output-md", default="", help="Summary Markdown path. Defaults to <suite-dir>/suite_summary.md.")
|
| 124 |
+
args = parser.parse_args()
|
| 125 |
+
|
| 126 |
+
suite_dir = Path(args.suite_dir)
|
| 127 |
+
rows = _load_rows(suite_dir)
|
| 128 |
+
summary = [_project(row) for row in _best_forward_backward(rows)]
|
| 129 |
+
json_path = Path(args.output_json) if args.output_json else suite_dir / "suite_summary.json"
|
| 130 |
+
md_path = Path(args.output_md) if args.output_md else suite_dir / "suite_summary.md"
|
| 131 |
+
json_path.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
|
| 132 |
+
_write_markdown(summary, md_path)
|
| 133 |
+
print(f"Wrote {json_path}")
|
| 134 |
+
print(f"Wrote {md_path}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
main()
|
code/TaoTrain/src/taoTrain.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
code/TaoTrain/src/taoTrain.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
train = taoTrain.cli:main
|
| 3 |
+
train-tokenizer = taoTrain.cli:train_tokenizer_command
|
| 4 |
+
tui-chat = taoTrain.inference.tui:main
|
code/TaoTrain/src/taoTrain/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TaoTrain: A clean, modular PyTorch LLM training framework.
|
| 3 |
+
|
| 4 |
+
For quick usage, see: https://github.com/your-org/taoTrain
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__version__ = "0.1.0"
|
| 8 |
+
__author__ = "Felix"
|
code/TaoTrain/src/taoTrain/cli.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main CLI entry point."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import click
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from taoTrain.config import load_config, load_tokenizer_config, TrainingModeEnum, PretrainConfig, SFTConfig, RLConfig
|
| 10 |
+
from taoTrain.utils import set_seed, get_device
|
| 11 |
+
from taoTrain.core import BaseModel, create_model, create_datasets
|
| 12 |
+
from taoTrain.data import get_dataloader
|
| 13 |
+
from taoTrain.training import PretrainTrainer, SFTTrainer, RLTrainer
|
| 14 |
+
from taoTrain.benchmarks import BenchmarkRunner
|
| 15 |
+
from taoTrain.tokenizers import TokenizerTrainer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@click.group()
|
| 19 |
+
def main():
|
| 20 |
+
"""TaoTrain: A clean, modular PyTorch LLM training framework."""
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@main.command()
|
| 25 |
+
@click.option(
|
| 26 |
+
"--config",
|
| 27 |
+
type=click.Path(exists=True),
|
| 28 |
+
required=True,
|
| 29 |
+
help="Path to training config file (YAML or JSON)",
|
| 30 |
+
)
|
| 31 |
+
def pretrain(config: str):
|
| 32 |
+
"""Pretrain a language model."""
|
| 33 |
+
_train_command(config, TrainingModeEnum.PRETRAIN)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@main.command()
|
| 37 |
+
@click.option(
|
| 38 |
+
"--config",
|
| 39 |
+
type=click.Path(exists=True),
|
| 40 |
+
required=True,
|
| 41 |
+
help="Path to training config file (YAML or JSON)",
|
| 42 |
+
)
|
| 43 |
+
def sft(config: str):
|
| 44 |
+
"""Supervised fine-tune a language model."""
|
| 45 |
+
_train_command(config, TrainingModeEnum.SFT)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@main.command()
|
| 49 |
+
@click.option(
|
| 50 |
+
"--config",
|
| 51 |
+
type=click.Path(exists=True),
|
| 52 |
+
required=True,
|
| 53 |
+
help="Path to training config file (YAML or JSON)",
|
| 54 |
+
)
|
| 55 |
+
def rl(config: str):
|
| 56 |
+
"""Train with reinforcement learning."""
|
| 57 |
+
_train_command(config, TrainingModeEnum.RL)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@main.command()
|
| 61 |
+
@click.option(
|
| 62 |
+
"--config",
|
| 63 |
+
type=click.Path(exists=True),
|
| 64 |
+
required=True,
|
| 65 |
+
help="Path to tokenizer config file (YAML or JSON)",
|
| 66 |
+
)
|
| 67 |
+
def train_tokenizer(config: str):
|
| 68 |
+
"""Train a SentencePiece tokenizer from a YAML/JSON config file."""
|
| 69 |
+
try:
|
| 70 |
+
click.echo("🚀 TaoTrain Tokenizer Trainer")
|
| 71 |
+
click.echo(f"{'=' * 50}")
|
| 72 |
+
|
| 73 |
+
# Load tokenizer config
|
| 74 |
+
click.echo(f"Loading config from {config}...")
|
| 75 |
+
tokenizer_config = load_tokenizer_config(config)
|
| 76 |
+
|
| 77 |
+
# Train tokenizer from config
|
| 78 |
+
result = TokenizerTrainer.train_from_config(tokenizer_config)
|
| 79 |
+
|
| 80 |
+
# Display results
|
| 81 |
+
click.echo(f"\n{'=' * 50}")
|
| 82 |
+
click.echo("✅ Tokenizer Training Complete!")
|
| 83 |
+
click.echo(f"\n📊 Configuration:")
|
| 84 |
+
click.echo(f" - Input file: {tokenizer_config.jsonl_path}")
|
| 85 |
+
click.echo(f" - Samples: {tokenizer_config.max_samples or 'all'}")
|
| 86 |
+
click.echo(f" - Output dir: {result['output_dir']}")
|
| 87 |
+
click.echo(f" - Vocab size: {result['vocab_size']}")
|
| 88 |
+
click.echo(f" - Model type: {result['model_type']}")
|
| 89 |
+
|
| 90 |
+
if tokenizer_config.special_tokens:
|
| 91 |
+
click.echo(f" - Special tokens: {tokenizer_config.special_tokens}")
|
| 92 |
+
|
| 93 |
+
click.echo(f"\n📁 Generated Files:")
|
| 94 |
+
click.echo(f" - Model: {result['model_file']}")
|
| 95 |
+
click.echo(f" - Vocab: {result['vocab_file']}")
|
| 96 |
+
|
| 97 |
+
click.echo(f"\n📝 Next Steps:")
|
| 98 |
+
click.echo(f" 1. Use this tokenizer in your pretraining config:")
|
| 99 |
+
click.echo(f" dataset:")
|
| 100 |
+
click.echo(f" local: true")
|
| 101 |
+
click.echo(f" jsonl_path: {tokenizer_config.jsonl_path}")
|
| 102 |
+
click.echo(f" tokenizer_path: {result['model_file']}")
|
| 103 |
+
click.echo(f"")
|
| 104 |
+
click.echo(f" 2. Run pretraining with:")
|
| 105 |
+
click.echo(f" train pretrain --config your_config.yaml")
|
| 106 |
+
|
| 107 |
+
except ImportError as e:
|
| 108 |
+
click.echo(f"❌ Error: {e}", err=True)
|
| 109 |
+
sys.exit(1)
|
| 110 |
+
except FileNotFoundError as e:
|
| 111 |
+
click.echo(f"❌ File Error: {e}", err=True)
|
| 112 |
+
sys.exit(1)
|
| 113 |
+
except ValueError as e:
|
| 114 |
+
click.echo(f"❌ Validation Error: {e}", err=True)
|
| 115 |
+
sys.exit(1)
|
| 116 |
+
except Exception as e:
|
| 117 |
+
click.echo(f"❌ Unexpected Error: {e}", err=True)
|
| 118 |
+
click.echo(f" Please report this issue.", err=True)
|
| 119 |
+
sys.exit(1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@click.command()
|
| 123 |
+
@click.option(
|
| 124 |
+
"--jsonl-path",
|
| 125 |
+
type=click.Path(exists=True),
|
| 126 |
+
required=True,
|
| 127 |
+
help="Path to JSONL file containing training data",
|
| 128 |
+
)
|
| 129 |
+
@click.option(
|
| 130 |
+
"--output-dir",
|
| 131 |
+
type=click.Path(),
|
| 132 |
+
default="tokenizers",
|
| 133 |
+
help="Directory to save tokenizer files",
|
| 134 |
+
)
|
| 135 |
+
@click.option(
|
| 136 |
+
"--vocab-size",
|
| 137 |
+
type=int,
|
| 138 |
+
default=50000,
|
| 139 |
+
help="Vocabulary size for the tokenizer",
|
| 140 |
+
)
|
| 141 |
+
@click.option(
|
| 142 |
+
"--model-type",
|
| 143 |
+
type=click.Choice(["unigram", "bpe", "char", "word"]),
|
| 144 |
+
default="unigram",
|
| 145 |
+
help="SentencePiece model type",
|
| 146 |
+
)
|
| 147 |
+
@click.option(
|
| 148 |
+
"--character-coverage",
|
| 149 |
+
type=float,
|
| 150 |
+
default=0.9995,
|
| 151 |
+
help="Character coverage for SentencePiece",
|
| 152 |
+
)
|
| 153 |
+
@click.option(
|
| 154 |
+
"--tokenizer-prefix",
|
| 155 |
+
type=str,
|
| 156 |
+
default=None,
|
| 157 |
+
help="Prefix for tokenizer output files (default: model_type)",
|
| 158 |
+
)
|
| 159 |
+
def train_tokenizer_command(
|
| 160 |
+
jsonl_path: str,
|
| 161 |
+
output_dir: str,
|
| 162 |
+
vocab_size: int,
|
| 163 |
+
model_type: str,
|
| 164 |
+
character_coverage: float,
|
| 165 |
+
tokenizer_prefix: Optional[str],
|
| 166 |
+
):
|
| 167 |
+
"""Train a SentencePiece tokenizer from JSONL data."""
|
| 168 |
+
try:
|
| 169 |
+
click.echo("🚀 TaoTrain Tokenizer Trainer")
|
| 170 |
+
click.echo(f"{'=' * 50}")
|
| 171 |
+
|
| 172 |
+
# Train tokenizer
|
| 173 |
+
result = TokenizerTrainer.train_sentencepiece(
|
| 174 |
+
jsonl_path=jsonl_path,
|
| 175 |
+
output_dir=output_dir,
|
| 176 |
+
vocab_size=vocab_size,
|
| 177 |
+
model_type=model_type,
|
| 178 |
+
character_coverage=character_coverage,
|
| 179 |
+
tokenizer_prefix=tokenizer_prefix,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Display results
|
| 183 |
+
click.echo(f"\n{'=' * 50}")
|
| 184 |
+
click.echo("✅ Tokenizer Training Complete!")
|
| 185 |
+
click.echo(f"\n📊 Configuration:")
|
| 186 |
+
click.echo(f" - Input file: {jsonl_path}")
|
| 187 |
+
click.echo(f" - Output dir: {result['output_dir']}")
|
| 188 |
+
click.echo(f" - Vocab size: {result['vocab_size']}")
|
| 189 |
+
click.echo(f" - Model type: {result['model_type']}")
|
| 190 |
+
|
| 191 |
+
click.echo(f"\n📁 Generated Files:")
|
| 192 |
+
click.echo(f" - Model: {result['model_file']}")
|
| 193 |
+
click.echo(f" - Vocab: {result['vocab_file']}")
|
| 194 |
+
|
| 195 |
+
click.echo(f"\n📝 Next Steps:")
|
| 196 |
+
click.echo(f" 1. Use this tokenizer in your pretraining config:")
|
| 197 |
+
click.echo(f" dataset:")
|
| 198 |
+
click.echo(f" local: true")
|
| 199 |
+
click.echo(f" jsonl_path: {jsonl_path}")
|
| 200 |
+
click.echo(f" tokenizer_path: {result['model_file']}")
|
| 201 |
+
click.echo(f"")
|
| 202 |
+
click.echo(f" 2. Run pretraining with:")
|
| 203 |
+
click.echo(f" train pretrain --config your_config.yaml")
|
| 204 |
+
|
| 205 |
+
except ImportError as e:
|
| 206 |
+
click.echo(f"❌ Error: {e}", err=True)
|
| 207 |
+
sys.exit(1)
|
| 208 |
+
except FileNotFoundError as e:
|
| 209 |
+
click.echo(f"❌ File Error: {e}", err=True)
|
| 210 |
+
sys.exit(1)
|
| 211 |
+
except ValueError as e:
|
| 212 |
+
click.echo(f"❌ Validation Error: {e}", err=True)
|
| 213 |
+
sys.exit(1)
|
| 214 |
+
except Exception as e:
|
| 215 |
+
click.echo(f"❌ Unexpected Error: {e}", err=True)
|
| 216 |
+
click.echo(f" Please report this issue.", err=True)
|
| 217 |
+
sys.exit(1)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Keep legacy CLI command as train-tokenizer-legacy for backward compatibility
|
| 221 |
+
main.add_command(train_tokenizer_command, name="train-tokenizer-legacy")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _train_command(config_path: str, mode: TrainingModeEnum):
|
| 225 |
+
"""Internal training command."""
|
| 226 |
+
try:
|
| 227 |
+
# Load config
|
| 228 |
+
click.echo(f"Loading config from {config_path}...")
|
| 229 |
+
train_config = load_config(config_path, mode)
|
| 230 |
+
|
| 231 |
+
# Set seed
|
| 232 |
+
set_seed(train_config.seed)
|
| 233 |
+
|
| 234 |
+
# Get device
|
| 235 |
+
device = get_device(train_config.device)
|
| 236 |
+
click.echo(f"Using device: {device}")
|
| 237 |
+
|
| 238 |
+
# Create model
|
| 239 |
+
click.echo("Creating model...")
|
| 240 |
+
model = create_model(train_config, device)
|
| 241 |
+
total_params, trainable_params = _count_params(model)
|
| 242 |
+
click.echo(f" - Total parameters: {total_params:,}")
|
| 243 |
+
click.echo(f" - Trainable parameters: {trainable_params:,}")
|
| 244 |
+
|
| 245 |
+
# Load pretrained checkpoint if provided (for SFT/RL)
|
| 246 |
+
if train_config.checkpoint_path:
|
| 247 |
+
click.echo(f"Loading pretrained checkpoint from {train_config.checkpoint_path}...")
|
| 248 |
+
from taoTrain.checkpointing.checkpoint import CheckpointManager
|
| 249 |
+
checkpoint_manager = CheckpointManager(train_config.checkpoint_dir)
|
| 250 |
+
checkpoint = checkpoint_manager.load(train_config.checkpoint_path, device=device)
|
| 251 |
+
|
| 252 |
+
# CheckpointManager.load() normalizes format and ensures 'model_state' key exists
|
| 253 |
+
if "model_state" in checkpoint:
|
| 254 |
+
model.load_state_dict(checkpoint["model_state"], strict=False)
|
| 255 |
+
click.echo(" ✓ Checkpoint loaded successfully")
|
| 256 |
+
else:
|
| 257 |
+
raise KeyError(f"Invalid checkpoint format: 'model_state' key not found. "
|
| 258 |
+
f"Available keys: {list(checkpoint.keys())}")
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Create datasets
|
| 263 |
+
click.echo("Loading datasets...")
|
| 264 |
+
train_dataset, val_dataset = create_datasets(train_config)
|
| 265 |
+
click.echo(f" - Train samples: {len(train_dataset)}")
|
| 266 |
+
if val_dataset:
|
| 267 |
+
click.echo(f" - Val samples: {len(val_dataset)}")
|
| 268 |
+
|
| 269 |
+
# Select trainer
|
| 270 |
+
if mode == TrainingModeEnum.PRETRAIN:
|
| 271 |
+
trainer_class = PretrainTrainer
|
| 272 |
+
elif mode == TrainingModeEnum.SFT:
|
| 273 |
+
trainer_class = SFTTrainer
|
| 274 |
+
elif mode == TrainingModeEnum.RL:
|
| 275 |
+
trainer_class = RLTrainer
|
| 276 |
+
else:
|
| 277 |
+
raise ValueError(f"Unknown training mode: {mode}")
|
| 278 |
+
|
| 279 |
+
# Create trainer
|
| 280 |
+
click.echo("Setting up trainer...")
|
| 281 |
+
trainer = trainer_class(
|
| 282 |
+
model=model,
|
| 283 |
+
train_dataset=train_dataset,
|
| 284 |
+
val_dataset=val_dataset,
|
| 285 |
+
config=train_config,
|
| 286 |
+
device=device,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Training loop
|
| 290 |
+
click.echo("\nStarting training...\n")
|
| 291 |
+
for epoch in range(train_config.num_epochs):
|
| 292 |
+
if train_config.max_steps and trainer.global_step >= train_config.max_steps:
|
| 293 |
+
break
|
| 294 |
+
|
| 295 |
+
epoch_metrics = trainer.train_epoch()
|
| 296 |
+
click.echo(f"\nEpoch {epoch + 1} complete")
|
| 297 |
+
click.echo(f" - Loss: {epoch_metrics.get('loss', 'N/A')}")
|
| 298 |
+
click.echo(f" - Learning rate: {epoch_metrics.get('lr', 'N/A')}")
|
| 299 |
+
|
| 300 |
+
# Final checkpoint
|
| 301 |
+
final_path = Path(train_config.checkpoint_dir) / "final_model.pt"
|
| 302 |
+
trainer.save_checkpoint(final_path)
|
| 303 |
+
click.echo(f"\nTraining complete! Final model saved to {final_path}")
|
| 304 |
+
|
| 305 |
+
# Log finish
|
| 306 |
+
trainer.logger.finish()
|
| 307 |
+
|
| 308 |
+
except Exception as e:
|
| 309 |
+
click.echo(f"Error during training: {e}", err=True)
|
| 310 |
+
sys.exit(1)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@main.command()
|
| 314 |
+
@click.option(
|
| 315 |
+
"--model",
|
| 316 |
+
type=click.Path(exists=True),
|
| 317 |
+
required=True,
|
| 318 |
+
help="Path to model checkpoint",
|
| 319 |
+
)
|
| 320 |
+
@click.option(
|
| 321 |
+
"--benchmark-type",
|
| 322 |
+
type=click.Choice(["all", "perplexity", "throughput", "memory"]),
|
| 323 |
+
default="all",
|
| 324 |
+
help="Type of benchmark to run",
|
| 325 |
+
)
|
| 326 |
+
@click.option(
|
| 327 |
+
"--batch-size",
|
| 328 |
+
type=int,
|
| 329 |
+
default=32,
|
| 330 |
+
help="Batch size for benchmarking",
|
| 331 |
+
)
|
| 332 |
+
@click.option(
|
| 333 |
+
"--seq-length",
|
| 334 |
+
type=int,
|
| 335 |
+
default=1024,
|
| 336 |
+
help="Sequence length for benchmarking",
|
| 337 |
+
)
|
| 338 |
+
def benchmark(model: str, benchmark_type: str, batch_size: int, seq_length: int):
|
| 339 |
+
"""Benchmark a trained model."""
|
| 340 |
+
try:
|
| 341 |
+
click.echo(f"Loading model from {model}...")
|
| 342 |
+
device = get_device("cuda")
|
| 343 |
+
runner = BenchmarkRunner.load_from_checkpoint(model, device=device)
|
| 344 |
+
|
| 345 |
+
click.echo("Running benchmarks...\n")
|
| 346 |
+
|
| 347 |
+
if benchmark_type == "throughput" or benchmark_type == "all":
|
| 348 |
+
click.echo("Throughput benchmark:")
|
| 349 |
+
results = runner.benchmark_throughput(batch_size, seq_length)
|
| 350 |
+
for key, val in results.items():
|
| 351 |
+
click.echo(f" {key}: {val:.2f}")
|
| 352 |
+
|
| 353 |
+
if benchmark_type == "memory" or benchmark_type == "all":
|
| 354 |
+
click.echo("\nMemory benchmark:")
|
| 355 |
+
results = runner.benchmark_memory()
|
| 356 |
+
for key, val in results.items():
|
| 357 |
+
click.echo(f" {key}: {val:.2f}")
|
| 358 |
+
|
| 359 |
+
click.echo("\nBenchmarking complete!")
|
| 360 |
+
|
| 361 |
+
except Exception as e:
|
| 362 |
+
click.echo(f"Error during benchmarking: {e}", err=True)
|
| 363 |
+
sys.exit(1)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
@main.command()
|
| 367 |
+
@click.option("--repo", type=str, default=".aim", help="AimStack repository path")
|
| 368 |
+
def view_logs(repo: str):
|
| 369 |
+
"""View training logs with AimStack."""
|
| 370 |
+
try:
|
| 371 |
+
import subprocess
|
| 372 |
+
click.echo(f"Opening AimStack dashboard for repo: {repo}")
|
| 373 |
+
subprocess.run(["aim", "up", "--repo", repo])
|
| 374 |
+
except FileNotFoundError:
|
| 375 |
+
click.echo("Error: 'aim' command not found. Install with: pip install aim", err=True)
|
| 376 |
+
sys.exit(1)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def _count_params(model: BaseModel) -> tuple[int, int]:
|
| 380 |
+
"""Count model parameters."""
|
| 381 |
+
total = sum(p.numel() for p in model.parameters())
|
| 382 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 383 |
+
return total, trainable
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
if __name__ == "__main__":
|
| 387 |
+
main()
|
code/TaoTrain/src/taoTrain/config.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic configuration schemas for TaoTrain."""
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Optional, Literal
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import json
|
| 7 |
+
from pydantic import BaseModel as PydanticBaseModel, Field, validator
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ============================================================================
|
| 12 |
+
# Enums
|
| 13 |
+
# ============================================================================
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DataTypeEnum(str, Enum):
|
| 17 |
+
"""Data types for training."""
|
| 18 |
+
FLOAT32 = "float32"
|
| 19 |
+
FLOAT16 = "float16"
|
| 20 |
+
BFLOAT16 = "bfloat16"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OptimizerEnum(str, Enum):
|
| 24 |
+
"""Supported optimizers."""
|
| 25 |
+
ADAM = "adam"
|
| 26 |
+
ADAMW = "adamw"
|
| 27 |
+
SGD = "sgd"
|
| 28 |
+
HYBRID_MUON_ADAMW = "hybrid_muon_adamw"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ModelArchitectureEnum(str, Enum):
|
| 32 |
+
"""Built-in model architectures."""
|
| 33 |
+
TRANSFORMER = "transformer"
|
| 34 |
+
TAONET = "taonet"
|
| 35 |
+
TAONET_SSM = "taonet_ssm"
|
| 36 |
+
TAONET_HYBRID = "taonet_hybrid"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SchedulerEnum(str, Enum):
|
| 40 |
+
"""Supported learning rate schedulers."""
|
| 41 |
+
LINEAR_WARMUP = "linearWarmup"
|
| 42 |
+
COSINE_WARMUP = "cosineWarmup"
|
| 43 |
+
CONSTANT = "constant"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RLMethodEnum(str, Enum):
|
| 47 |
+
"""Supported RL training methods."""
|
| 48 |
+
PPO = "ppo"
|
| 49 |
+
DPO = "dpo"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TrainingModeEnum(str, Enum):
|
| 53 |
+
"""Training stages."""
|
| 54 |
+
PRETRAIN = "pretrain"
|
| 55 |
+
SFT = "sft"
|
| 56 |
+
RL = "rl"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ============================================================================
|
| 60 |
+
# Base Configs
|
| 61 |
+
# ============================================================================
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class BaseConfig(PydanticBaseModel):
|
| 65 |
+
"""Base Pydantic model with utility methods."""
|
| 66 |
+
|
| 67 |
+
class Config:
|
| 68 |
+
"""Pydantic config."""
|
| 69 |
+
arbitrary_types_allowed = True
|
| 70 |
+
|
| 71 |
+
def to_dict(self) -> dict:
|
| 72 |
+
"""Convert to dictionary."""
|
| 73 |
+
data = self.model_dump(mode='json') # Enums -> strings
|
| 74 |
+
return data
|
| 75 |
+
|
| 76 |
+
def to_json_str(self) -> str:
|
| 77 |
+
"""Convert to JSON string."""
|
| 78 |
+
return json.dumps(self.to_dict(), indent=2)
|
| 79 |
+
|
| 80 |
+
def save_yaml(self, path: str | Path) -> None:
|
| 81 |
+
"""Save config to YAML file."""
|
| 82 |
+
path = Path(path)
|
| 83 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 84 |
+
with open(path, 'w') as f:
|
| 85 |
+
yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)
|
| 86 |
+
|
| 87 |
+
def save_json(self, path: str | Path) -> None:
|
| 88 |
+
"""Save config to JSON file."""
|
| 89 |
+
path = Path(path)
|
| 90 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 91 |
+
with open(path, 'w') as f:
|
| 92 |
+
f.write(self.to_json_str())
|
| 93 |
+
|
| 94 |
+
@classmethod
|
| 95 |
+
def load_yaml(cls, path: str | Path) -> "BaseConfig":
|
| 96 |
+
"""Load config from YAML file."""
|
| 97 |
+
with open(path) as f:
|
| 98 |
+
data = yaml.safe_load(f)
|
| 99 |
+
return cls(**data)
|
| 100 |
+
|
| 101 |
+
@classmethod
|
| 102 |
+
def load_json(cls, path: str | Path) -> "BaseConfig":
|
| 103 |
+
"""Load config from JSON file."""
|
| 104 |
+
with open(path) as f:
|
| 105 |
+
data = json.load(f)
|
| 106 |
+
return cls(**data)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ============================================================================
|
| 110 |
+
# Model Config
|
| 111 |
+
# ============================================================================
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class ModelConfig(BaseConfig):
|
| 115 |
+
"""Configuration for model architecture."""
|
| 116 |
+
|
| 117 |
+
architecture_type: ModelArchitectureEnum = Field(
|
| 118 |
+
default=ModelArchitectureEnum.TRANSFORMER,
|
| 119 |
+
description="Type of model architecture"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Transformer-specific
|
| 123 |
+
vocab_size: int = Field(default=50257, description="Vocabulary size")
|
| 124 |
+
hidden_dim: int = Field(default=768, description="Hidden dimension")
|
| 125 |
+
num_layers: int = Field(default=12, description="Number of transformer blocks")
|
| 126 |
+
num_heads: int = Field(default=12, description="Number of attention heads")
|
| 127 |
+
head_dim: Optional[int] = Field(
|
| 128 |
+
default=None,
|
| 129 |
+
description="Head dimension (defaults to hidden_dim // num_heads)"
|
| 130 |
+
)
|
| 131 |
+
intermediate_dim: Optional[int] = Field(
|
| 132 |
+
default=None,
|
| 133 |
+
description="FFN intermediate dimension (defaults to 4 * hidden_dim)"
|
| 134 |
+
)
|
| 135 |
+
dropout: float = Field(default=0.1, description="Dropout rate")
|
| 136 |
+
max_seq_length: int = Field(default=2048, description="Maximum sequence length")
|
| 137 |
+
|
| 138 |
+
# TaoNet (DeepSeek MLA) specific
|
| 139 |
+
d_latent_kv: Optional[int] = Field(
|
| 140 |
+
default=None,
|
| 141 |
+
description="KV compression dimension for MLA (defaults to 3/4 * hidden_dim). Only used for taonet architecture."
|
| 142 |
+
)
|
| 143 |
+
d_rope: Optional[int] = Field(
|
| 144 |
+
default=None,
|
| 145 |
+
description="RoPE dimension per head (defaults to hidden_dim // num_heads). Only used for taonet architecture."
|
| 146 |
+
)
|
| 147 |
+
gqa_groups: int = Field(
|
| 148 |
+
default=1,
|
| 149 |
+
description="Grouped Query Attention groups (1 = standard MLA, >1 = GQA). Only used for taonet architecture."
|
| 150 |
+
)
|
| 151 |
+
hidden_dim_ff: Optional[int] = Field(
|
| 152 |
+
default=None,
|
| 153 |
+
description="Feed-forward intermediate dimension (defaults to 4 * hidden_dim)."
|
| 154 |
+
)
|
| 155 |
+
use_factorized_embedding: bool = Field(
|
| 156 |
+
default=False,
|
| 157 |
+
description="Use low-rank factorized embedding instead of standard embedding (reduces params). Only for taonet."
|
| 158 |
+
)
|
| 159 |
+
d_embed_rank: int = Field(
|
| 160 |
+
default=96,
|
| 161 |
+
description="Rank dimension for factorized embedding. Only used if use_factorized_embedding=True."
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# YaRN (Yet another RoPE eXtension) for context length extension
|
| 165 |
+
rope_scale: float = Field(
|
| 166 |
+
default=40.0,
|
| 167 |
+
description="Base RoPE scale factor (default: 40.0). Controls position frequency base."
|
| 168 |
+
)
|
| 169 |
+
yarn_enabled: bool = Field(
|
| 170 |
+
default=False,
|
| 171 |
+
description="Enable YaRN (Yet another RoPE eXtension) for context length interpolation."
|
| 172 |
+
)
|
| 173 |
+
yarn_alpha: float = Field(
|
| 174 |
+
default=1.0,
|
| 175 |
+
description="YaRN interpolation smoothness (1.0=smooth, <1.0=aggressive, >1.0=conservative). Only used if yarn_enabled=True."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# TaoNet-SSM specific: SSM mixer replacing MLA attention
|
| 179 |
+
ssm_core: Literal["gamma_s4", "dplr"] = Field(
|
| 180 |
+
default="gamma_s4",
|
| 181 |
+
description="SSM core used by taonet_ssm. Use dplr for the ternary-aware DPLR SSM."
|
| 182 |
+
)
|
| 183 |
+
ssm_hidden_dim: Optional[int] = Field(
|
| 184 |
+
default=None,
|
| 185 |
+
description="SSM hidden/state dimension for taonet_ssm. Defaults to d_latent_kv or hidden_dim."
|
| 186 |
+
)
|
| 187 |
+
ssm_mixer_dim: Optional[int] = Field(
|
| 188 |
+
default=None,
|
| 189 |
+
description="Channel dimension processed by the SSM mixer. Defaults to hidden_dim; smaller values use an input/output projection bottleneck."
|
| 190 |
+
)
|
| 191 |
+
ssm_num_lanes: int = Field(
|
| 192 |
+
default=1,
|
| 193 |
+
description="Number of independent SSM lanes inside each SSM mixer. Multiple lanes add SSM capacity with cheap elementwise combination."
|
| 194 |
+
)
|
| 195 |
+
ssm_lane_combine: Literal["mean", "channel"] = Field(
|
| 196 |
+
default="mean",
|
| 197 |
+
description="How to combine multiple SSM lanes. Channel uses learned per-lane/per-channel elementwise weights."
|
| 198 |
+
)
|
| 199 |
+
ssm_lane_mode: Literal["full", "split"] = Field(
|
| 200 |
+
default="full",
|
| 201 |
+
description="Whether each SSM lane processes the full mixer dimension or a disjoint split of the mixer channels."
|
| 202 |
+
)
|
| 203 |
+
ssm_split_mix: Literal["none", "hadamard"] = Field(
|
| 204 |
+
default="none",
|
| 205 |
+
description="Optional ternary-friendly cross-lane mixer for split SSM lanes."
|
| 206 |
+
)
|
| 207 |
+
ssm_rank: int = Field(
|
| 208 |
+
default=1,
|
| 209 |
+
description="Low-rank correction rank for ssm_core=dplr."
|
| 210 |
+
)
|
| 211 |
+
ssm_max_low_rank_scale: float = Field(
|
| 212 |
+
default=0.1,
|
| 213 |
+
description="Maximum low-rank correction scale for ssm_core=dplr."
|
| 214 |
+
)
|
| 215 |
+
ssm_finite_tail_correction: bool = Field(
|
| 216 |
+
default=True,
|
| 217 |
+
description="Enable exact finite-length tail correction for ssm_core=dplr. Disable for the faster approximate DPLR path."
|
| 218 |
+
)
|
| 219 |
+
ssm_discretization: Literal["bilinear", "zoh", "euler"] = Field(
|
| 220 |
+
default="bilinear",
|
| 221 |
+
description="Discretization used by the Gamma SSM mixer."
|
| 222 |
+
)
|
| 223 |
+
ssm_kernel_mode: Literal["auto", "recurrent", "conv", "conv_transfer"] = Field(
|
| 224 |
+
default="auto",
|
| 225 |
+
description="Gamma SSM execution path. Use auto/conv for full-sequence GPU training, conv_transfer to materialize frequency transfers, recurrent for step-wise tests."
|
| 226 |
+
)
|
| 227 |
+
ssm_kernel_threshold: int = Field(
|
| 228 |
+
default=64,
|
| 229 |
+
description="Minimum sequence length for auto mode to use the convolutional Gamma SSM path."
|
| 230 |
+
)
|
| 231 |
+
ssm_dt_min: float = Field(default=1e-3, description="Minimum learned SSM timestep.")
|
| 232 |
+
ssm_dt_max: float = Field(default=1e-1, description="Maximum learned SSM timestep.")
|
| 233 |
+
ssm_dt_init: float = Field(default=1e-2, description="Initial learned SSM timestep.")
|
| 234 |
+
ssm_use_d: bool = Field(default=True, description="Enable direct skip term D in the Gamma SSM.")
|
| 235 |
+
ssm_activation: Literal["gelu", "silu", "identity", "linear"] = Field(
|
| 236 |
+
default="gelu",
|
| 237 |
+
description="Activation applied to the Gamma SSM branch output."
|
| 238 |
+
)
|
| 239 |
+
ssm_gate: bool = Field(default=True, description="Enable output gate on the Gamma SSM branch.")
|
| 240 |
+
ssm_input_gate: bool = Field(default=True, description="Enable input gate before the Gamma SSM.")
|
| 241 |
+
ssm_gate_type: Literal["dense", "channel"] = Field(
|
| 242 |
+
default="dense",
|
| 243 |
+
description="Gate implementation for enabled SSM input/output gates. Channel gates are elementwise and ternary-friendly."
|
| 244 |
+
)
|
| 245 |
+
ssm_use_padding_mask: bool = Field(
|
| 246 |
+
default=False,
|
| 247 |
+
description="Apply dataset padding masks inside the SSM. Disabled by default so training can use the convolutional path."
|
| 248 |
+
)
|
| 249 |
+
ssm_layer_scale_init: float = Field(
|
| 250 |
+
default=0.1,
|
| 251 |
+
description="Initial layer-scale multiplier for the Gamma SSM branch."
|
| 252 |
+
)
|
| 253 |
+
ssm_branch_rms_norm: bool = Field(
|
| 254 |
+
default=False,
|
| 255 |
+
description="Normalize the SSM residual branch to unit RMS before layer-scale. Useful for stabilizing deep SSM/hybrid runs."
|
| 256 |
+
)
|
| 257 |
+
ssm_branch_rms_eps: float = Field(
|
| 258 |
+
default=1e-6,
|
| 259 |
+
description="Numerical epsilon for optional SSM branch RMS normalization."
|
| 260 |
+
)
|
| 261 |
+
ssm_branch_clip_value: Optional[float] = Field(
|
| 262 |
+
default=None,
|
| 263 |
+
description="Optional symmetric clamp applied to the SSM residual branch after layer-scale. None disables clamping."
|
| 264 |
+
)
|
| 265 |
+
block_residual_rms_norm: bool = Field(
|
| 266 |
+
default=False,
|
| 267 |
+
description="Normalize the residual stream RMS after block residual additions. Intended for stabilizing deep SSM/hybrid experiments."
|
| 268 |
+
)
|
| 269 |
+
block_residual_rms_target: float = Field(
|
| 270 |
+
default=1.0,
|
| 271 |
+
description="Target per-token RMS when block_residual_rms_norm is enabled."
|
| 272 |
+
)
|
| 273 |
+
block_residual_rms_cap: Optional[float] = Field(
|
| 274 |
+
default=None,
|
| 275 |
+
description="Optional per-token RMS cap for the residual stream. Unlike block_residual_rms_norm, this only scales down tokens whose RMS exceeds the cap."
|
| 276 |
+
)
|
| 277 |
+
block_residual_rms_eps: float = Field(
|
| 278 |
+
default=1e-6,
|
| 279 |
+
description="Numerical epsilon for optional block residual RMS normalization."
|
| 280 |
+
)
|
| 281 |
+
ssm_local_shift: bool = Field(
|
| 282 |
+
default=False,
|
| 283 |
+
description="Add a cheap one-token causal shift/register branch to the taonet_ssm mixer."
|
| 284 |
+
)
|
| 285 |
+
ssm_local_shift_init: float = Field(
|
| 286 |
+
default=0.1,
|
| 287 |
+
description="Initial scalar weight for the optional one-token local shift/register branch."
|
| 288 |
+
)
|
| 289 |
+
ssm_local_shift_per_channel: bool = Field(
|
| 290 |
+
default=False,
|
| 291 |
+
description="Use one learned local-shift gain per model channel instead of one scalar."
|
| 292 |
+
)
|
| 293 |
+
hybrid_pattern: Literal["attention_first", "ssm_first", "single_ssm_middle", "single_ssm_late"] = Field(
|
| 294 |
+
default="attention_first",
|
| 295 |
+
description="Layer pattern for taonet_hybrid when hybrid_ssm_layers is not set."
|
| 296 |
+
)
|
| 297 |
+
hybrid_ssm_layers: Optional[str] = Field(
|
| 298 |
+
default=None,
|
| 299 |
+
description="Optional comma-separated 0-based layer indices that should use SSM blocks in taonet_hybrid."
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Initializations
|
| 303 |
+
init_std: float = Field(default=0.02, description="Weight initialization standard deviation")
|
| 304 |
+
|
| 305 |
+
@validator("head_dim", always=True)
|
| 306 |
+
def validate_head_dim(cls, v, values):
|
| 307 |
+
"""Validate head dimension."""
|
| 308 |
+
if v is None and 'hidden_dim' in values:
|
| 309 |
+
return values['hidden_dim'] // values.get('num_heads', 12)
|
| 310 |
+
return v
|
| 311 |
+
|
| 312 |
+
@validator("intermediate_dim", always=True)
|
| 313 |
+
def validate_intermediate_dim(cls, v, values):
|
| 314 |
+
"""Validate intermediate dimension."""
|
| 315 |
+
if v is None and 'hidden_dim' in values:
|
| 316 |
+
return 4 * values['hidden_dim']
|
| 317 |
+
return v
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ============================================================================
|
| 321 |
+
# Dataset Config
|
| 322 |
+
# ============================================================================
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class DatasetConfig(BaseConfig):
|
| 326 |
+
"""Configuration for dataset loading."""
|
| 327 |
+
|
| 328 |
+
# Local vs HuggingFace dataset selection
|
| 329 |
+
local: bool = Field(default=False, description="Use local JSONL dataset instead of HuggingFace")
|
| 330 |
+
|
| 331 |
+
# HuggingFace dataset fields
|
| 332 |
+
dataset_name: Optional[str] = Field(default=None, description="HuggingFace dataset name (e.g., 'wikitext', 'openwebtext')")
|
| 333 |
+
split: str = Field(default="train", description="Dataset split to use")
|
| 334 |
+
config: Optional[str] = Field(default=None, description="Dataset config if multi-config (e.g., 'wikitext-103')")
|
| 335 |
+
|
| 336 |
+
# Local JSONL dataset fields
|
| 337 |
+
jsonl_path: Optional[str] = Field(default=None, description="Path to local JSONL dataset file")
|
| 338 |
+
text_field: str = Field(default="text", description="Name of text field in JSONL")
|
| 339 |
+
|
| 340 |
+
# Text column name varies by dataset
|
| 341 |
+
text_column: str = Field(default="text", description="Name of text column in dataset")
|
| 342 |
+
|
| 343 |
+
# Preprocessing
|
| 344 |
+
max_samples: Optional[int] = Field(
|
| 345 |
+
default=None,
|
| 346 |
+
description="Limit dataset to N samples (useful for debugging)"
|
| 347 |
+
)
|
| 348 |
+
cache_dir: str = Field(default=".cache/datasets", description="HuggingFace cache directory")
|
| 349 |
+
|
| 350 |
+
# For SFT/RL datasets with instruction-response format
|
| 351 |
+
instruction_column: Optional[str] = Field(default=None, description="Instruction column for SFT")
|
| 352 |
+
response_column: Optional[str] = Field(default=None, description="Response column for SFT")
|
| 353 |
+
prompt_column: Optional[str] = Field(default=None, description="Prompt column for RL")
|
| 354 |
+
|
| 355 |
+
# Instruction template
|
| 356 |
+
instruction_template: Optional[str] = Field(
|
| 357 |
+
default=None,
|
| 358 |
+
description="Template for combining instruction and response. E.g., '{instruction}\\n{response}'"
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Tokenizer configuration
|
| 362 |
+
tokenizer_type: Optional[str] = Field(
|
| 363 |
+
default=None,
|
| 364 |
+
description="Tokenizer type: 'huggingface' or 'sentencepiece'. If None, defaults based on tokenizer_path."
|
| 365 |
+
)
|
| 366 |
+
tokenizer_path: Optional[str] = Field(
|
| 367 |
+
default=None,
|
| 368 |
+
description="Path to saved tokenizer (for SentencePiece: .model file, for HuggingFace: model name or local path)"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
# Chunked loading for large JSONL files
|
| 372 |
+
enable_streaming: bool = Field(
|
| 373 |
+
default=True,
|
| 374 |
+
description="Enable streaming/chunked loading for large JSONL files to reduce memory usage"
|
| 375 |
+
)
|
| 376 |
+
chunk_size_gb: float = Field(
|
| 377 |
+
default=5.0,
|
| 378 |
+
description="Approximate chunk size in GB (ignored if samples_per_chunk is set)"
|
| 379 |
+
)
|
| 380 |
+
samples_per_chunk: Optional[int] = Field(
|
| 381 |
+
default=1000,
|
| 382 |
+
description="Number of samples per chunk (takes precedence over chunk_size_gb). Default: 1000 samples"
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Chunk caching
|
| 386 |
+
enable_chunk_metadata_cache: bool = Field(
|
| 387 |
+
default=True,
|
| 388 |
+
description="Enable caching of chunk metadata (file scan results) to avoid re-scanning large JSONL files"
|
| 389 |
+
)
|
| 390 |
+
enable_chunk_data_cache: bool = Field(
|
| 391 |
+
default=False,
|
| 392 |
+
description="Enable caching of actual chunk data as separate files for faster loading (uses more disk space)"
|
| 393 |
+
)
|
| 394 |
+
chunk_cache_dir: str = Field(
|
| 395 |
+
default=".cache/chunks",
|
| 396 |
+
description="Directory to store chunk metadata and data cache files"
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
# Tokenization parallelization
|
| 400 |
+
tokenizer_threads: int = Field(
|
| 401 |
+
default=1,
|
| 402 |
+
description="Number of background threads for tokenization (1-32 recommended). Higher values speed up tokenization but increase memory usage."
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
@validator('jsonl_path', always=True)
|
| 406 |
+
def validate_dataset_source(cls, v, values):
|
| 407 |
+
"""Validate that either local JSONL or HuggingFace dataset is specified."""
|
| 408 |
+
local = values.get('local', False)
|
| 409 |
+
dataset_name = values.get('dataset_name')
|
| 410 |
+
|
| 411 |
+
if local and not v:
|
| 412 |
+
raise ValueError("jsonl_path must be provided when local=True")
|
| 413 |
+
if not local and not dataset_name:
|
| 414 |
+
raise ValueError("dataset_name must be provided when local=False (HuggingFace dataset)")
|
| 415 |
+
|
| 416 |
+
return v
|
| 417 |
+
|
| 418 |
+
@validator('tokenizer_threads')
|
| 419 |
+
def validate_tokenizer_threads(cls, v):
|
| 420 |
+
"""Validate tokenizer_threads is a positive integer."""
|
| 421 |
+
if v < 1:
|
| 422 |
+
raise ValueError("tokenizer_threads must be at least 1")
|
| 423 |
+
if v > 128:
|
| 424 |
+
raise ValueError("tokenizer_threads should not exceed 128 (recommended: 1-32)")
|
| 425 |
+
return v
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ============================================================================
|
| 429 |
+
# Tokenizer Config
|
| 430 |
+
# ============================================================================
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class TokenizerConfig(BaseConfig):
|
| 434 |
+
"""Configuration for tokenizer training."""
|
| 435 |
+
|
| 436 |
+
# Dataset source
|
| 437 |
+
jsonl_path: str = Field(description="Path to JSONL file containing training data")
|
| 438 |
+
text_field: str = Field(default="text", description="Field name in JSONL for text data")
|
| 439 |
+
|
| 440 |
+
# Training configuration
|
| 441 |
+
vocab_size: int = Field(default=50000, description="Vocabulary size")
|
| 442 |
+
model_type: str = Field(default="unigram", description="SentencePiece model type (unigram, bpe, char, word)")
|
| 443 |
+
character_coverage: float = Field(
|
| 444 |
+
default=0.9995,
|
| 445 |
+
description="Character coverage for SentencePiece training"
|
| 446 |
+
)
|
| 447 |
+
output_dir: str = Field(default="tokenizers", description="Directory to save trained tokenizer")
|
| 448 |
+
tokenizer_prefix: Optional[str] = Field(
|
| 449 |
+
default=None,
|
| 450 |
+
description="Prefix for tokenizer output files (default: model_type)"
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# SentencePiece token IDs
|
| 454 |
+
unk_id: int = Field(default=0, description="Unknown token ID")
|
| 455 |
+
bos_id: int = Field(default=1, description="Beginning of sentence token ID")
|
| 456 |
+
eos_id: int = Field(default=2, description="End of sentence token ID")
|
| 457 |
+
pad_id: int = Field(default=3, description="Padding token ID")
|
| 458 |
+
|
| 459 |
+
# Custom special tokens - add custom tokens like <PAD>, <EOS>, <BOS>, <UNK>, <think>, <user>, <assistant>, <image>
|
| 460 |
+
special_tokens: Optional[dict[str, int]] = Field(
|
| 461 |
+
default=None,
|
| 462 |
+
description="Custom special tokens mapping: {token: id}. Example: {'<think>': 4, '<user>': 5, '<assistant>': 6, '<image>': 7}"
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
# Data sampling
|
| 466 |
+
max_samples: Optional[int] = Field(
|
| 467 |
+
default=None,
|
| 468 |
+
description="Limit training to first N samples from JSONL (useful for quick testing)"
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# Tokenizer metadata
|
| 472 |
+
tokenizer_name: Optional[str] = Field(
|
| 473 |
+
default=None,
|
| 474 |
+
description="Optional name for the tokenizer"
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# ============================================================================
|
| 479 |
+
# Training Config
|
| 480 |
+
# ============================================================================
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class OptimizerConfig(BaseConfig):
|
| 484 |
+
"""Optimizer configuration."""
|
| 485 |
+
|
| 486 |
+
optimizer_type: OptimizerEnum = Field(default=OptimizerEnum.ADAMW, description="Optimizer type")
|
| 487 |
+
learning_rate: float = Field(default=1e-4, description="Peak learning rate (for Muon 2D weights)")
|
| 488 |
+
adamw_lr: Optional[float] = Field(
|
| 489 |
+
default=None,
|
| 490 |
+
description="Learning rate for AdamW (1D parameters). If None, defaults to learning_rate / 10. Used in hybrid_muon_adamw optimizer."
|
| 491 |
+
)
|
| 492 |
+
weight_decay: float = Field(default=1e-2, description="Weight decay (L2 regularization)")
|
| 493 |
+
betas: tuple[float, float] = Field(default=(0.9, 0.999), description="Adam betas")
|
| 494 |
+
eps: float = Field(default=1e-8, description="Optimizer epsilon")
|
| 495 |
+
|
| 496 |
+
@validator('adamw_lr', always=True)
|
| 497 |
+
def set_default_adamw_lr(cls, v, values):
|
| 498 |
+
"""Set default adamw_lr as 1/10 of learning_rate if not specified."""
|
| 499 |
+
if v is None and 'learning_rate' in values:
|
| 500 |
+
return values['learning_rate'] / 10
|
| 501 |
+
return v
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class SchedulerConfig(BaseConfig):
|
| 505 |
+
"""Learning rate scheduler configuration."""
|
| 506 |
+
|
| 507 |
+
scheduler_type: SchedulerEnum = Field(default=SchedulerEnum.LINEAR_WARMUP, description="Scheduler type")
|
| 508 |
+
warmup_steps: int = Field(default=0, description="Number of warmup steps (takes precedence over warmup_ratio)")
|
| 509 |
+
warmup_ratio: float = Field(default=0.1, description="Warmup as fraction of total steps (used if warmup_steps=0)")
|
| 510 |
+
|
| 511 |
+
# Cosine scheduler specific
|
| 512 |
+
num_cycles: float = Field(default=0.5, description="Number of cycles for cosine schedule")
|
| 513 |
+
last_epoch: int = Field(default=-1, description="Last epoch for scheduler")
|
| 514 |
+
|
| 515 |
+
# TaoNet 3-phase scheduler (warmup -> steady -> cosine decay)
|
| 516 |
+
steady_ratio: float = Field(
|
| 517 |
+
default=0.0,
|
| 518 |
+
description="Fraction of training steps at peak LR before cosine decay (0.0 = no steady phase). Only for cosineWarmup."
|
| 519 |
+
)
|
| 520 |
+
min_lr_ratio: float = Field(
|
| 521 |
+
default=0.0,
|
| 522 |
+
description="Minimum LR as fraction of peak LR at end of training (0.0 = decay to 0). Only for cosineWarmup."
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
@validator('warmup_ratio')
|
| 526 |
+
def validate_warmup_ratio(cls, v):
|
| 527 |
+
"""Validate warmup ratio is between 0 and 1."""
|
| 528 |
+
if not 0 <= v <= 1:
|
| 529 |
+
raise ValueError("warmup_ratio must be between 0 and 1")
|
| 530 |
+
return v
|
| 531 |
+
|
| 532 |
+
@validator('steady_ratio')
|
| 533 |
+
def validate_steady_ratio(cls, v):
|
| 534 |
+
"""Validate steady ratio is between 0 and 1."""
|
| 535 |
+
if not 0 <= v <= 1:
|
| 536 |
+
raise ValueError("steady_ratio must be between 0 and 1")
|
| 537 |
+
return v
|
| 538 |
+
|
| 539 |
+
@validator('min_lr_ratio')
|
| 540 |
+
def validate_min_lr_ratio(cls, v):
|
| 541 |
+
"""Validate min_lr_ratio is between 0 and 1."""
|
| 542 |
+
if not 0 <= v <= 1:
|
| 543 |
+
raise ValueError("min_lr_ratio must be between 0 and 1")
|
| 544 |
+
return v
|
| 545 |
+
|
| 546 |
+
@validator('warmup_steps')
|
| 547 |
+
def validate_warmup_steps(cls, v):
|
| 548 |
+
"""Validate warmup steps is non-negative."""
|
| 549 |
+
if v < 0:
|
| 550 |
+
raise ValueError("warmup_steps must be non-negative")
|
| 551 |
+
return v
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class TrainingConfig(BaseConfig):
|
| 555 |
+
"""Base training configuration shared across all modes."""
|
| 556 |
+
|
| 557 |
+
# Data and model
|
| 558 |
+
model: ModelConfig = Field(default_factory=ModelConfig, description="Model configuration")
|
| 559 |
+
dataset: DatasetConfig = Field(description="Dataset configuration")
|
| 560 |
+
|
| 561 |
+
# Training hyperparameters
|
| 562 |
+
batch_size: int = Field(default=32, description="Batch size per device")
|
| 563 |
+
num_epochs: int = Field(default=3, description="Number of training epochs")
|
| 564 |
+
max_steps: Optional[int] = Field(
|
| 565 |
+
default=None,
|
| 566 |
+
description="Maximum steps (overrides num_epochs if set)"
|
| 567 |
+
)
|
| 568 |
+
gradient_accumulation_steps: int = Field(
|
| 569 |
+
default=1,
|
| 570 |
+
description="Gradient accumulation steps"
|
| 571 |
+
)
|
| 572 |
+
max_grad_norm: float = Field(default=1.0, description="Gradient clipping max norm")
|
| 573 |
+
|
| 574 |
+
# Optimizer
|
| 575 |
+
optimizer: OptimizerConfig = Field(
|
| 576 |
+
default_factory=OptimizerConfig,
|
| 577 |
+
description="Optimizer configuration"
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
# Scheduler
|
| 581 |
+
scheduler: SchedulerConfig = Field(
|
| 582 |
+
default_factory=SchedulerConfig,
|
| 583 |
+
description="Learning rate scheduler configuration"
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
# Data type and device
|
| 587 |
+
dtype: DataTypeEnum = Field(
|
| 588 |
+
default=DataTypeEnum.BFLOAT16,
|
| 589 |
+
description="Training data type"
|
| 590 |
+
)
|
| 591 |
+
device: str = Field(default="cuda", description="Device to train on (cuda, cpu)")
|
| 592 |
+
seed: int = Field(default=42, description="Random seed")
|
| 593 |
+
|
| 594 |
+
# Checkpointing
|
| 595 |
+
checkpoint_dir: str = Field(default="checkpoints", description="Directory to save checkpoints")
|
| 596 |
+
checkpoint_path: Optional[str] = Field(
|
| 597 |
+
default=None,
|
| 598 |
+
description="Path to load pretrained checkpoint (for SFT/RL). If provided, loads weights before training starts."
|
| 599 |
+
)
|
| 600 |
+
save_every_steps: int = Field(default=500, description="Save checkpoint every N steps")
|
| 601 |
+
keep_last_n_checkpoints: int = Field(default=3, description="Keep only last N checkpoints")
|
| 602 |
+
save_best_model: bool = Field(default=True, description="Save best model based on validation loss")
|
| 603 |
+
|
| 604 |
+
# Validation
|
| 605 |
+
eval_every_steps: int = Field(default=500, description="Evaluate every N steps")
|
| 606 |
+
eval_samples: int = Field(default=1000, description="Number of validation samples")
|
| 607 |
+
|
| 608 |
+
# Logging
|
| 609 |
+
log_every_steps: int = Field(default=10, description="Log metrics every N steps")
|
| 610 |
+
aim_repo: str = Field(default=".aim", description="AimStack repository path")
|
| 611 |
+
|
| 612 |
+
# Misc
|
| 613 |
+
num_workers: int = Field(default=0, description="Number of DataLoader workers")
|
| 614 |
+
pin_memory: bool = Field(default=True, description="Pin memory for DataLoader")
|
| 615 |
+
use_compile: bool = Field(default=False, description="Use torch.compile (experimental)")
|
| 616 |
+
|
| 617 |
+
# Mode
|
| 618 |
+
mode: TrainingModeEnum = Field(default=TrainingModeEnum.PRETRAIN, description="Training mode")
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
# ============================================================================
|
| 622 |
+
# Stage-Specific Configs
|
| 623 |
+
# ============================================================================
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class PretrainConfig(TrainingConfig):
|
| 627 |
+
"""Configuration for pretraining."""
|
| 628 |
+
|
| 629 |
+
mode: Literal[TrainingModeEnum.PRETRAIN] = TrainingModeEnum.PRETRAIN
|
| 630 |
+
|
| 631 |
+
# Pretraining-specific
|
| 632 |
+
sequence_length: int = Field(default=1024, description="Sequence length for pretraining")
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class SFTConfig(TrainingConfig):
|
| 636 |
+
"""Configuration for supervised fine-tuning."""
|
| 637 |
+
|
| 638 |
+
mode: Literal[TrainingModeEnum.SFT] = TrainingModeEnum.SFT
|
| 639 |
+
|
| 640 |
+
# SFT-specific
|
| 641 |
+
response_loss_only: bool = Field(
|
| 642 |
+
default=True,
|
| 643 |
+
description="Only compute loss on response/assistant tokens (not instruction/user tokens). Uses -100 label masking."
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
# Multi-turn conversation role tokens
|
| 647 |
+
user_token: str = Field(
|
| 648 |
+
default="<user>",
|
| 649 |
+
description="Special token representing user/instruction role in conversations"
|
| 650 |
+
)
|
| 651 |
+
assistant_token: str = Field(
|
| 652 |
+
default="<assistant>",
|
| 653 |
+
description="Special token representing assistant/response role in conversations"
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class RLConfig(TrainingConfig):
|
| 658 |
+
"""Configuration for reinforcement learning training."""
|
| 659 |
+
|
| 660 |
+
mode: Literal[TrainingModeEnum.RL] = TrainingModeEnum.RL
|
| 661 |
+
|
| 662 |
+
# RL-specific
|
| 663 |
+
rl_method: RLMethodEnum = Field(
|
| 664 |
+
default=RLMethodEnum.PPO,
|
| 665 |
+
description="RL training method (PPO or DPO)"
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
# Reward model
|
| 669 |
+
reward_model_path: str = Field(description="Path to trained reward model checkpoint")
|
| 670 |
+
|
| 671 |
+
# PPO-specific
|
| 672 |
+
ppo_epochs: int = Field(default=4, description="PPO inner epochs")
|
| 673 |
+
ppo_clip_ratio: float = Field(default=0.2, description="PPO clipping ratio")
|
| 674 |
+
entropy_coeff: float = Field(default=0.01, description="Entropy bonus coefficient")
|
| 675 |
+
value_loss_coeff: float = Field(default=1.0, description="Value function loss coefficient")
|
| 676 |
+
|
| 677 |
+
# DPO-specific (Direct Preference Optimization)
|
| 678 |
+
dpo_beta: float = Field(default=0.1, description="DPO inverse temperature (beta)")
|
| 679 |
+
|
| 680 |
+
# Prompt distribution
|
| 681 |
+
prompt_dataset: Optional[DatasetConfig] = Field(
|
| 682 |
+
default=None,
|
| 683 |
+
description="Separate dataset for prompts (if different from main dataset)"
|
| 684 |
+
)
|
| 685 |
+
generation_max_length: int = Field(
|
| 686 |
+
default=256,
|
| 687 |
+
description="Maximum length for generated responses during RL"
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
# ============================================================================
|
| 692 |
+
# Factory function
|
| 693 |
+
# ============================================================================
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def load_config(path: str | Path, mode: TrainingModeEnum | str) -> TrainingConfig:
|
| 697 |
+
"""Load config file and return appropriate config class."""
|
| 698 |
+
if isinstance(mode, str):
|
| 699 |
+
mode = TrainingModeEnum(mode)
|
| 700 |
+
|
| 701 |
+
config_map = {
|
| 702 |
+
TrainingModeEnum.PRETRAIN: PretrainConfig,
|
| 703 |
+
TrainingModeEnum.SFT: SFTConfig,
|
| 704 |
+
TrainingModeEnum.RL: RLConfig,
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
config_class = config_map[mode]
|
| 708 |
+
|
| 709 |
+
path = Path(path)
|
| 710 |
+
if path.suffix == '.yaml' or path.suffix == '.yml':
|
| 711 |
+
return config_class.load_yaml(path)
|
| 712 |
+
elif path.suffix == '.json':
|
| 713 |
+
return config_class.load_json(path)
|
| 714 |
+
else:
|
| 715 |
+
raise ValueError(f"Unsupported config file format: {path.suffix}")
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def load_tokenizer_config(path: str | Path) -> TokenizerConfig:
|
| 719 |
+
"""Load tokenizer config from YAML or JSON file."""
|
| 720 |
+
path = Path(path)
|
| 721 |
+
if path.suffix == '.yaml' or path.suffix == '.yml':
|
| 722 |
+
return TokenizerConfig.load_yaml(path)
|
| 723 |
+
elif path.suffix == '.json':
|
| 724 |
+
return TokenizerConfig.load_json(path)
|
| 725 |
+
else:
|
| 726 |
+
raise ValueError(f"Unsupported config file format: {path.suffix}")
|