Add files using upload-large-folder tool
Browse files- LTA_openwebtext_dualt/logs/lta_lm1b_classic_dirichlet_len512_gbs512_4gpu_10k_save1k_20260523.train.pid +1 -0
- LTA_openwebtext_dualt/logs/noise_geometry_combo_4gpu/20260517_170456.log +994 -0
- LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/driver.log +0 -0
- LTA_openwebtext_dualt/scripts/apple_to_apple_lta_checks.py +631 -0
- LTA_openwebtext_dualt/scripts/build_lta_owt_compact_gpt2bpe_stream1024_train_minus_100k_np8.sh +13 -0
- LTA_openwebtext_dualt/scripts/build_owt_t5_elf_dataset.py +587 -0
- LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_state_20260508.py +51 -0
- LTA_openwebtext_dualt/scripts/infer_lta_owt_t5_len128_uniform10k_then_lognsr_latest.sh +113 -0
- LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c1024_fullycoupled_8gpu_small_1m.sh +150 -0
- LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c16_dualt_4gpu_small_1m.sh +155 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_c1024_fullycoupled_8gpu_len1024_gpt2_cached_chunks_1m.sh +60 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v8192_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh +39 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_elfaligned_t5_logitnormal_8gpu.sh +209 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_outwd0p5_8gpu.sh +11 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_grad_k1_rho025_subset10k_4gpu_100k.sh +148 -0
- LTA_openwebtext_dualt/scripts/run_lta_lm1b_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh +34 -0
- LTA_openwebtext_dualt/scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh +34 -0
- LTA_openwebtext_dualt/scripts/run_lta_owt_t5_absrope_adaln_dirichlet_len1024_Cv_to_2v_8gpu_mask0p1_1p0_sameT_1m_save10k.sh +36 -0
- LTA_openwebtext_dualt/scripts/run_train8_wrong_floor_pilots_4gpu.sh +194 -0
- LTA_openwebtext_dualt/scripts/watch_infer_owt_classic_fullvocab_len1024_lr2e4_gbs2048_latest_every1k_t1p45.sh +158 -0
LTA_openwebtext_dualt/logs/lta_lm1b_classic_dirichlet_len512_gbs512_4gpu_10k_save1k_20260523.train.pid
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
993819
|
LTA_openwebtext_dualt/logs/noise_geometry_combo_4gpu/20260517_170456.log
ADDED
|
@@ -0,0 +1,994 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[combo-pilot] start stamp=20260517_170456 len=256 vocab=969 out=docs/lta_samples/metrics_20260517/noise_geometry_combo_len256_bs512_ode128_20260517_170456
|
| 2 |
+
[combo-pilot] round=1 Sun May 17 17:04:56 UTC 2026
|
| 3 |
+
[combo-pilot] train config=logistic_unigram_shared_highC from=0 to=1000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.0
|
| 4 |
+
[combo-pilot] eval config=logistic_unigram_shared_highC step=1000
|
| 5 |
+
[eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_20260517_170456 step=1000 soft=none
|
| 6 |
+
[decode] max_len=256 generated=64/64
|
| 7 |
+
{
|
| 8 |
+
"num_rows": 1,
|
| 9 |
+
"best_by_run": {
|
| 10 |
+
"train8_combo_len256_logistic_unigram_shared_highC_20260517_170456::none": {
|
| 11 |
+
"run": "train8_combo_len256_logistic_unigram_shared_highC_20260517_170456",
|
| 12 |
+
"checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_20260517_170456/step_0001000.pt",
|
| 13 |
+
"ckpt_step": 1000,
|
| 14 |
+
"endpoint_softening": "none",
|
| 15 |
+
"decode_rule": "flowmap",
|
| 16 |
+
"steps": 128,
|
| 17 |
+
"time_schedule": "logit_normal",
|
| 18 |
+
"model_t_mode": "post",
|
| 19 |
+
"final_from": "state",
|
| 20 |
+
"n_gen": 64,
|
| 21 |
+
"n_refs": 8,
|
| 22 |
+
"token_acc_mean": 0.0487060546875,
|
| 23 |
+
"token_acc_min": 0.03515625,
|
| 24 |
+
"token_acc_max": 0.07421875,
|
| 25 |
+
"exact_acc": 0.0,
|
| 26 |
+
"exact_count": 0,
|
| 27 |
+
"exact_ref_coverage": 0.0,
|
| 28 |
+
"exact_ref_count": 0,
|
| 29 |
+
"exact_ref_hits": [],
|
| 30 |
+
"best_ref_idx": [
|
| 31 |
+
5,
|
| 32 |
+
0,
|
| 33 |
+
0,
|
| 34 |
+
0,
|
| 35 |
+
5,
|
| 36 |
+
5,
|
| 37 |
+
5,
|
| 38 |
+
0,
|
| 39 |
+
5,
|
| 40 |
+
2,
|
| 41 |
+
1,
|
| 42 |
+
0,
|
| 43 |
+
7,
|
| 44 |
+
2,
|
| 45 |
+
7,
|
| 46 |
+
0,
|
| 47 |
+
3,
|
| 48 |
+
3,
|
| 49 |
+
2,
|
| 50 |
+
0,
|
| 51 |
+
2,
|
| 52 |
+
2,
|
| 53 |
+
5,
|
| 54 |
+
7,
|
| 55 |
+
5,
|
| 56 |
+
7,
|
| 57 |
+
7,
|
| 58 |
+
2,
|
| 59 |
+
5,
|
| 60 |
+
7,
|
| 61 |
+
5,
|
| 62 |
+
2,
|
| 63 |
+
1,
|
| 64 |
+
5,
|
| 65 |
+
0,
|
| 66 |
+
0,
|
| 67 |
+
5,
|
| 68 |
+
2,
|
| 69 |
+
0,
|
| 70 |
+
0,
|
| 71 |
+
2,
|
| 72 |
+
0,
|
| 73 |
+
0,
|
| 74 |
+
5,
|
| 75 |
+
5,
|
| 76 |
+
3,
|
| 77 |
+
5,
|
| 78 |
+
5,
|
| 79 |
+
5,
|
| 80 |
+
3,
|
| 81 |
+
3,
|
| 82 |
+
0,
|
| 83 |
+
3,
|
| 84 |
+
2,
|
| 85 |
+
5,
|
| 86 |
+
0,
|
| 87 |
+
7,
|
| 88 |
+
0,
|
| 89 |
+
1,
|
| 90 |
+
5,
|
| 91 |
+
2,
|
| 92 |
+
7,
|
| 93 |
+
3,
|
| 94 |
+
2
|
| 95 |
+
],
|
| 96 |
+
"best_token_acc": [
|
| 97 |
+
0.04296875,
|
| 98 |
+
0.04296875,
|
| 99 |
+
0.04296875,
|
| 100 |
+
0.046875,
|
| 101 |
+
0.05859375,
|
| 102 |
+
0.04296875,
|
| 103 |
+
0.04296875,
|
| 104 |
+
0.05859375,
|
| 105 |
+
0.046875,
|
| 106 |
+
0.05859375,
|
| 107 |
+
0.04296875,
|
| 108 |
+
0.05859375,
|
| 109 |
+
0.0390625,
|
| 110 |
+
0.046875,
|
| 111 |
+
0.0625,
|
| 112 |
+
0.0390625,
|
| 113 |
+
0.04296875,
|
| 114 |
+
0.046875,
|
| 115 |
+
0.046875,
|
| 116 |
+
0.046875,
|
| 117 |
+
0.05078125,
|
| 118 |
+
0.05078125,
|
| 119 |
+
0.04296875,
|
| 120 |
+
0.0546875,
|
| 121 |
+
0.046875,
|
| 122 |
+
0.046875,
|
| 123 |
+
0.046875,
|
| 124 |
+
0.046875,
|
| 125 |
+
0.0625,
|
| 126 |
+
0.0625,
|
| 127 |
+
0.05078125,
|
| 128 |
+
0.0390625,
|
| 129 |
+
0.0546875,
|
| 130 |
+
0.046875,
|
| 131 |
+
0.04296875,
|
| 132 |
+
0.0390625,
|
| 133 |
+
0.05078125,
|
| 134 |
+
0.0390625,
|
| 135 |
+
0.046875,
|
| 136 |
+
0.04296875,
|
| 137 |
+
0.03515625,
|
| 138 |
+
0.046875,
|
| 139 |
+
0.046875,
|
| 140 |
+
0.0546875,
|
| 141 |
+
0.0546875,
|
| 142 |
+
0.04296875,
|
| 143 |
+
0.04296875,
|
| 144 |
+
0.0546875,
|
| 145 |
+
0.04296875,
|
| 146 |
+
0.046875,
|
| 147 |
+
0.05078125,
|
| 148 |
+
0.07421875,
|
| 149 |
+
0.04296875,
|
| 150 |
+
0.05078125,
|
| 151 |
+
0.046875,
|
| 152 |
+
0.0546875,
|
| 153 |
+
0.0546875,
|
| 154 |
+
0.04296875,
|
| 155 |
+
0.0546875,
|
| 156 |
+
0.0546875,
|
| 157 |
+
0.0546875,
|
| 158 |
+
0.05078125,
|
| 159 |
+
0.04296875,
|
| 160 |
+
0.05078125
|
| 161 |
+
]
|
| 162 |
+
}
|
| 163 |
+
},
|
| 164 |
+
"first_exact_by_run": {}
|
| 165 |
+
}
|
| 166 |
+
RESULT config=logistic_unigram_shared_highC ckpt_step=1000 views=512000 token_acc=0.0487 exact=0/64 exact_refs=0 hits=[]
|
| 167 |
+
[combo-pilot] continue config=logistic_unigram_shared_highC step=1000
|
| 168 |
+
[combo-pilot] train config=logistic_unigram_shared_highC_seqrand from=0 to=1000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.5
|
| 169 |
+
[combo-pilot] eval config=logistic_unigram_shared_highC_seqrand step=1000
|
| 170 |
+
[eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456 step=1000 soft=none
|
| 171 |
+
[decode] max_len=256 generated=64/64
|
| 172 |
+
{
|
| 173 |
+
"num_rows": 1,
|
| 174 |
+
"best_by_run": {
|
| 175 |
+
"train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456::none": {
|
| 176 |
+
"run": "train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456",
|
| 177 |
+
"checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456/step_0001000.pt",
|
| 178 |
+
"ckpt_step": 1000,
|
| 179 |
+
"endpoint_softening": "none",
|
| 180 |
+
"decode_rule": "flowmap",
|
| 181 |
+
"steps": 128,
|
| 182 |
+
"time_schedule": "logit_normal",
|
| 183 |
+
"model_t_mode": "post",
|
| 184 |
+
"final_from": "state",
|
| 185 |
+
"n_gen": 64,
|
| 186 |
+
"n_refs": 8,
|
| 187 |
+
"token_acc_mean": 0.04034423828125,
|
| 188 |
+
"token_acc_min": 0.0234375,
|
| 189 |
+
"token_acc_max": 0.0625,
|
| 190 |
+
"exact_acc": 0.0,
|
| 191 |
+
"exact_count": 0,
|
| 192 |
+
"exact_ref_coverage": 0.0,
|
| 193 |
+
"exact_ref_count": 0,
|
| 194 |
+
"exact_ref_hits": [],
|
| 195 |
+
"best_ref_idx": [
|
| 196 |
+
0,
|
| 197 |
+
0,
|
| 198 |
+
0,
|
| 199 |
+
0,
|
| 200 |
+
0,
|
| 201 |
+
3,
|
| 202 |
+
0,
|
| 203 |
+
7,
|
| 204 |
+
0,
|
| 205 |
+
4,
|
| 206 |
+
0,
|
| 207 |
+
5,
|
| 208 |
+
4,
|
| 209 |
+
0,
|
| 210 |
+
0,
|
| 211 |
+
0,
|
| 212 |
+
3,
|
| 213 |
+
0,
|
| 214 |
+
0,
|
| 215 |
+
0,
|
| 216 |
+
3,
|
| 217 |
+
0,
|
| 218 |
+
0,
|
| 219 |
+
0,
|
| 220 |
+
4,
|
| 221 |
+
0,
|
| 222 |
+
0,
|
| 223 |
+
5,
|
| 224 |
+
0,
|
| 225 |
+
4,
|
| 226 |
+
0,
|
| 227 |
+
0,
|
| 228 |
+
0,
|
| 229 |
+
0,
|
| 230 |
+
5,
|
| 231 |
+
0,
|
| 232 |
+
0,
|
| 233 |
+
0,
|
| 234 |
+
0,
|
| 235 |
+
4,
|
| 236 |
+
0,
|
| 237 |
+
0,
|
| 238 |
+
0,
|
| 239 |
+
5,
|
| 240 |
+
3,
|
| 241 |
+
0,
|
| 242 |
+
0,
|
| 243 |
+
0,
|
| 244 |
+
0,
|
| 245 |
+
4,
|
| 246 |
+
0,
|
| 247 |
+
4,
|
| 248 |
+
0,
|
| 249 |
+
0,
|
| 250 |
+
0,
|
| 251 |
+
0,
|
| 252 |
+
5,
|
| 253 |
+
0,
|
| 254 |
+
0,
|
| 255 |
+
0,
|
| 256 |
+
4,
|
| 257 |
+
0,
|
| 258 |
+
3,
|
| 259 |
+
0
|
| 260 |
+
],
|
| 261 |
+
"best_token_acc": [
|
| 262 |
+
0.03515625,
|
| 263 |
+
0.03515625,
|
| 264 |
+
0.03125,
|
| 265 |
+
0.05859375,
|
| 266 |
+
0.03515625,
|
| 267 |
+
0.0234375,
|
| 268 |
+
0.03515625,
|
| 269 |
+
0.02734375,
|
| 270 |
+
0.0625,
|
| 271 |
+
0.03515625,
|
| 272 |
+
0.02734375,
|
| 273 |
+
0.03125,
|
| 274 |
+
0.0234375,
|
| 275 |
+
0.03515625,
|
| 276 |
+
0.046875,
|
| 277 |
+
0.04296875,
|
| 278 |
+
0.05078125,
|
| 279 |
+
0.03125,
|
| 280 |
+
0.03515625,
|
| 281 |
+
0.0625,
|
| 282 |
+
0.03125,
|
| 283 |
+
0.04296875,
|
| 284 |
+
0.02734375,
|
| 285 |
+
0.04296875,
|
| 286 |
+
0.03125,
|
| 287 |
+
0.0390625,
|
| 288 |
+
0.05078125,
|
| 289 |
+
0.0390625,
|
| 290 |
+
0.02734375,
|
| 291 |
+
0.03125,
|
| 292 |
+
0.03125,
|
| 293 |
+
0.0234375,
|
| 294 |
+
0.046875,
|
| 295 |
+
0.05078125,
|
| 296 |
+
0.04296875,
|
| 297 |
+
0.03515625,
|
| 298 |
+
0.05078125,
|
| 299 |
+
0.04296875,
|
| 300 |
+
0.0390625,
|
| 301 |
+
0.05078125,
|
| 302 |
+
0.0390625,
|
| 303 |
+
0.046875,
|
| 304 |
+
0.0390625,
|
| 305 |
+
0.0390625,
|
| 306 |
+
0.02734375,
|
| 307 |
+
0.05078125,
|
| 308 |
+
0.05078125,
|
| 309 |
+
0.046875,
|
| 310 |
+
0.04296875,
|
| 311 |
+
0.046875,
|
| 312 |
+
0.05859375,
|
| 313 |
+
0.05859375,
|
| 314 |
+
0.04296875,
|
| 315 |
+
0.05078125,
|
| 316 |
+
0.05078125,
|
| 317 |
+
0.046875,
|
| 318 |
+
0.03125,
|
| 319 |
+
0.04296875,
|
| 320 |
+
0.0390625,
|
| 321 |
+
0.05078125,
|
| 322 |
+
0.03125,
|
| 323 |
+
0.03125,
|
| 324 |
+
0.03515625,
|
| 325 |
+
0.0390625
|
| 326 |
+
]
|
| 327 |
+
}
|
| 328 |
+
},
|
| 329 |
+
"first_exact_by_run": {}
|
| 330 |
+
}
|
| 331 |
+
RESULT config=logistic_unigram_shared_highC_seqrand ckpt_step=1000 views=512000 token_acc=0.0403 exact=0/64 exact_refs=0 hits=[]
|
| 332 |
+
[combo-pilot] continue config=logistic_unigram_shared_highC_seqrand step=1000
|
| 333 |
+
[combo-pilot] train config=logistic_unigram_shared_C1024 from=0 to=1000 sampler=logistic_normal_linear_mean C=1.0->1024 unigram_shared=0.5 seq=0.0
|
| 334 |
+
[combo-pilot] eval config=logistic_unigram_shared_C1024 step=1000
|
| 335 |
+
[eval-decode-acc] train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456 step=1000 soft=none
|
| 336 |
+
[decode] max_len=256 generated=64/64
|
| 337 |
+
{
|
| 338 |
+
"num_rows": 1,
|
| 339 |
+
"best_by_run": {
|
| 340 |
+
"train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456::none": {
|
| 341 |
+
"run": "train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456",
|
| 342 |
+
"checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456/step_0001000.pt",
|
| 343 |
+
"ckpt_step": 1000,
|
| 344 |
+
"endpoint_softening": "none",
|
| 345 |
+
"decode_rule": "flowmap",
|
| 346 |
+
"steps": 128,
|
| 347 |
+
"time_schedule": "logit_normal",
|
| 348 |
+
"model_t_mode": "post",
|
| 349 |
+
"final_from": "state",
|
| 350 |
+
"n_gen": 64,
|
| 351 |
+
"n_refs": 8,
|
| 352 |
+
"token_acc_mean": 0.0487060546875,
|
| 353 |
+
"token_acc_min": 0.03515625,
|
| 354 |
+
"token_acc_max": 0.07421875,
|
| 355 |
+
"exact_acc": 0.0,
|
| 356 |
+
"exact_count": 0,
|
| 357 |
+
"exact_ref_coverage": 0.0,
|
| 358 |
+
"exact_ref_count": 0,
|
| 359 |
+
"exact_ref_hits": [],
|
| 360 |
+
"best_ref_idx": [
|
| 361 |
+
5,
|
| 362 |
+
0,
|
| 363 |
+
0,
|
| 364 |
+
0,
|
| 365 |
+
5,
|
| 366 |
+
5,
|
| 367 |
+
5,
|
| 368 |
+
0,
|
| 369 |
+
5,
|
| 370 |
+
2,
|
| 371 |
+
1,
|
| 372 |
+
0,
|
| 373 |
+
7,
|
| 374 |
+
2,
|
| 375 |
+
7,
|
| 376 |
+
0,
|
| 377 |
+
3,
|
| 378 |
+
3,
|
| 379 |
+
2,
|
| 380 |
+
0,
|
| 381 |
+
2,
|
| 382 |
+
2,
|
| 383 |
+
5,
|
| 384 |
+
7,
|
| 385 |
+
5,
|
| 386 |
+
7,
|
| 387 |
+
7,
|
| 388 |
+
2,
|
| 389 |
+
5,
|
| 390 |
+
7,
|
| 391 |
+
5,
|
| 392 |
+
2,
|
| 393 |
+
1,
|
| 394 |
+
5,
|
| 395 |
+
0,
|
| 396 |
+
0,
|
| 397 |
+
5,
|
| 398 |
+
2,
|
| 399 |
+
0,
|
| 400 |
+
0,
|
| 401 |
+
2,
|
| 402 |
+
0,
|
| 403 |
+
0,
|
| 404 |
+
5,
|
| 405 |
+
5,
|
| 406 |
+
3,
|
| 407 |
+
5,
|
| 408 |
+
5,
|
| 409 |
+
5,
|
| 410 |
+
3,
|
| 411 |
+
3,
|
| 412 |
+
0,
|
| 413 |
+
3,
|
| 414 |
+
2,
|
| 415 |
+
5,
|
| 416 |
+
0,
|
| 417 |
+
7,
|
| 418 |
+
0,
|
| 419 |
+
1,
|
| 420 |
+
5,
|
| 421 |
+
2,
|
| 422 |
+
7,
|
| 423 |
+
3,
|
| 424 |
+
2
|
| 425 |
+
],
|
| 426 |
+
"best_token_acc": [
|
| 427 |
+
0.04296875,
|
| 428 |
+
0.04296875,
|
| 429 |
+
0.04296875,
|
| 430 |
+
0.046875,
|
| 431 |
+
0.05859375,
|
| 432 |
+
0.04296875,
|
| 433 |
+
0.04296875,
|
| 434 |
+
0.05859375,
|
| 435 |
+
0.046875,
|
| 436 |
+
0.05859375,
|
| 437 |
+
0.04296875,
|
| 438 |
+
0.05859375,
|
| 439 |
+
0.0390625,
|
| 440 |
+
0.046875,
|
| 441 |
+
0.0625,
|
| 442 |
+
0.0390625,
|
| 443 |
+
0.04296875,
|
| 444 |
+
0.046875,
|
| 445 |
+
0.046875,
|
| 446 |
+
0.046875,
|
| 447 |
+
0.05078125,
|
| 448 |
+
0.05078125,
|
| 449 |
+
0.04296875,
|
| 450 |
+
0.0546875,
|
| 451 |
+
0.046875,
|
| 452 |
+
0.046875,
|
| 453 |
+
0.046875,
|
| 454 |
+
0.046875,
|
| 455 |
+
0.0625,
|
| 456 |
+
0.0625,
|
| 457 |
+
0.05078125,
|
| 458 |
+
0.0390625,
|
| 459 |
+
0.0546875,
|
| 460 |
+
0.046875,
|
| 461 |
+
0.04296875,
|
| 462 |
+
0.0390625,
|
| 463 |
+
0.05078125,
|
| 464 |
+
0.0390625,
|
| 465 |
+
0.046875,
|
| 466 |
+
0.04296875,
|
| 467 |
+
0.03515625,
|
| 468 |
+
0.046875,
|
| 469 |
+
0.046875,
|
| 470 |
+
0.0546875,
|
| 471 |
+
0.0546875,
|
| 472 |
+
0.04296875,
|
| 473 |
+
0.04296875,
|
| 474 |
+
0.0546875,
|
| 475 |
+
0.04296875,
|
| 476 |
+
0.046875,
|
| 477 |
+
0.05078125,
|
| 478 |
+
0.07421875,
|
| 479 |
+
0.04296875,
|
| 480 |
+
0.05078125,
|
| 481 |
+
0.046875,
|
| 482 |
+
0.0546875,
|
| 483 |
+
0.0546875,
|
| 484 |
+
0.04296875,
|
| 485 |
+
0.0546875,
|
| 486 |
+
0.0546875,
|
| 487 |
+
0.0546875,
|
| 488 |
+
0.05078125,
|
| 489 |
+
0.04296875,
|
| 490 |
+
0.05078125
|
| 491 |
+
]
|
| 492 |
+
}
|
| 493 |
+
},
|
| 494 |
+
"first_exact_by_run": {}
|
| 495 |
+
}
|
| 496 |
+
RESULT config=logistic_unigram_shared_C1024 ckpt_step=1000 views=512000 token_acc=0.0487 exact=0/64 exact_refs=0 hits=[]
|
| 497 |
+
[combo-pilot] continue config=logistic_unigram_shared_C1024 step=1000
|
| 498 |
+
[combo-pilot] train config=dirichlet_unigram_shared_highC from=0 to=1000 sampler=dirichlet C=64->4096 unigram_shared=0.5 seq=0.0
|
| 499 |
+
[combo-pilot] eval config=dirichlet_unigram_shared_highC step=1000
|
| 500 |
+
[eval-decode-acc] train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456 step=1000 soft=none
|
| 501 |
+
[decode] max_len=256 generated=64/64
|
| 502 |
+
{
|
| 503 |
+
"num_rows": 1,
|
| 504 |
+
"best_by_run": {
|
| 505 |
+
"train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456::none": {
|
| 506 |
+
"run": "train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456",
|
| 507 |
+
"checkpoint": "runs/train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456/step_0001000.pt",
|
| 508 |
+
"ckpt_step": 1000,
|
| 509 |
+
"endpoint_softening": "none",
|
| 510 |
+
"decode_rule": "flowmap",
|
| 511 |
+
"steps": 128,
|
| 512 |
+
"time_schedule": "logit_normal",
|
| 513 |
+
"model_t_mode": "post",
|
| 514 |
+
"final_from": "state",
|
| 515 |
+
"n_gen": 64,
|
| 516 |
+
"n_refs": 8,
|
| 517 |
+
"token_acc_mean": 0.03857421875,
|
| 518 |
+
"token_acc_min": 0.02734375,
|
| 519 |
+
"token_acc_max": 0.05078125,
|
| 520 |
+
"exact_acc": 0.0,
|
| 521 |
+
"exact_count": 0,
|
| 522 |
+
"exact_ref_coverage": 0.0,
|
| 523 |
+
"exact_ref_count": 0,
|
| 524 |
+
"exact_ref_hits": [],
|
| 525 |
+
"best_ref_idx": [
|
| 526 |
+
1,
|
| 527 |
+
1,
|
| 528 |
+
1,
|
| 529 |
+
2,
|
| 530 |
+
1,
|
| 531 |
+
1,
|
| 532 |
+
0,
|
| 533 |
+
1,
|
| 534 |
+
0,
|
| 535 |
+
1,
|
| 536 |
+
0,
|
| 537 |
+
0,
|
| 538 |
+
1,
|
| 539 |
+
1,
|
| 540 |
+
1,
|
| 541 |
+
1,
|
| 542 |
+
1,
|
| 543 |
+
1,
|
| 544 |
+
1,
|
| 545 |
+
2,
|
| 546 |
+
0,
|
| 547 |
+
1,
|
| 548 |
+
1,
|
| 549 |
+
2,
|
| 550 |
+
1,
|
| 551 |
+
1,
|
| 552 |
+
0,
|
| 553 |
+
0,
|
| 554 |
+
1,
|
| 555 |
+
0,
|
| 556 |
+
2,
|
| 557 |
+
1,
|
| 558 |
+
1,
|
| 559 |
+
0,
|
| 560 |
+
0,
|
| 561 |
+
1,
|
| 562 |
+
0,
|
| 563 |
+
2,
|
| 564 |
+
0,
|
| 565 |
+
1,
|
| 566 |
+
1,
|
| 567 |
+
1,
|
| 568 |
+
1,
|
| 569 |
+
1,
|
| 570 |
+
0,
|
| 571 |
+
5,
|
| 572 |
+
2,
|
| 573 |
+
1,
|
| 574 |
+
0,
|
| 575 |
+
2,
|
| 576 |
+
1,
|
| 577 |
+
1,
|
| 578 |
+
1,
|
| 579 |
+
2,
|
| 580 |
+
1,
|
| 581 |
+
0,
|
| 582 |
+
1,
|
| 583 |
+
1,
|
| 584 |
+
1,
|
| 585 |
+
1,
|
| 586 |
+
1,
|
| 587 |
+
1,
|
| 588 |
+
0,
|
| 589 |
+
1
|
| 590 |
+
],
|
| 591 |
+
"best_token_acc": [
|
| 592 |
+
0.03125,
|
| 593 |
+
0.04296875,
|
| 594 |
+
0.046875,
|
| 595 |
+
0.0390625,
|
| 596 |
+
0.0390625,
|
| 597 |
+
0.0390625,
|
| 598 |
+
0.04296875,
|
| 599 |
+
0.03515625,
|
| 600 |
+
0.0390625,
|
| 601 |
+
0.03515625,
|
| 602 |
+
0.03125,
|
| 603 |
+
0.02734375,
|
| 604 |
+
0.03515625,
|
| 605 |
+
0.03125,
|
| 606 |
+
0.03515625,
|
| 607 |
+
0.03515625,
|
| 608 |
+
0.03515625,
|
| 609 |
+
0.04296875,
|
| 610 |
+
0.04296875,
|
| 611 |
+
0.03125,
|
| 612 |
+
0.02734375,
|
| 613 |
+
0.03125,
|
| 614 |
+
0.04296875,
|
| 615 |
+
0.0390625,
|
| 616 |
+
0.0390625,
|
| 617 |
+
0.03515625,
|
| 618 |
+
0.03515625,
|
| 619 |
+
0.0390625,
|
| 620 |
+
0.046875,
|
| 621 |
+
0.03515625,
|
| 622 |
+
0.05078125,
|
| 623 |
+
0.0390625,
|
| 624 |
+
0.046875,
|
| 625 |
+
0.04296875,
|
| 626 |
+
0.0390625,
|
| 627 |
+
0.0390625,
|
| 628 |
+
0.0390625,
|
| 629 |
+
0.04296875,
|
| 630 |
+
0.03125,
|
| 631 |
+
0.046875,
|
| 632 |
+
0.03515625,
|
| 633 |
+
0.046875,
|
| 634 |
+
0.046875,
|
| 635 |
+
0.04296875,
|
| 636 |
+
0.03125,
|
| 637 |
+
0.03515625,
|
| 638 |
+
0.03515625,
|
| 639 |
+
0.0390625,
|
| 640 |
+
0.03125,
|
| 641 |
+
0.046875,
|
| 642 |
+
0.0390625,
|
| 643 |
+
0.05078125,
|
| 644 |
+
0.0390625,
|
| 645 |
+
0.02734375,
|
| 646 |
+
0.02734375,
|
| 647 |
+
0.0390625,
|
| 648 |
+
0.05078125,
|
| 649 |
+
0.03125,
|
| 650 |
+
0.03515625,
|
| 651 |
+
0.04296875,
|
| 652 |
+
0.0390625,
|
| 653 |
+
0.04296875,
|
| 654 |
+
0.0390625,
|
| 655 |
+
0.046875
|
| 656 |
+
]
|
| 657 |
+
}
|
| 658 |
+
},
|
| 659 |
+
"first_exact_by_run": {}
|
| 660 |
+
}
|
| 661 |
+
RESULT config=dirichlet_unigram_shared_highC ckpt_step=1000 views=512000 token_acc=0.0386 exact=0/64 exact_refs=0 hits=[]
|
| 662 |
+
[combo-pilot] continue config=dirichlet_unigram_shared_highC step=1000
|
| 663 |
+
[combo-pilot] round=2 Sun May 17 17:08:26 UTC 2026
|
| 664 |
+
[combo-pilot] train config=logistic_unigram_shared_highC from=1000 to=2000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.0
|
| 665 |
+
[combo-pilot] eval config=logistic_unigram_shared_highC step=2000
|
| 666 |
+
[eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_20260517_170456 step=2000 soft=none
|
| 667 |
+
[decode] max_len=256 generated=64/64
|
| 668 |
+
{
|
| 669 |
+
"num_rows": 1,
|
| 670 |
+
"best_by_run": {
|
| 671 |
+
"train8_combo_len256_logistic_unigram_shared_highC_20260517_170456::none": {
|
| 672 |
+
"run": "train8_combo_len256_logistic_unigram_shared_highC_20260517_170456",
|
| 673 |
+
"checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_20260517_170456/step_0002000.pt",
|
| 674 |
+
"ckpt_step": 2000,
|
| 675 |
+
"endpoint_softening": "none",
|
| 676 |
+
"decode_rule": "flowmap",
|
| 677 |
+
"steps": 128,
|
| 678 |
+
"time_schedule": "logit_normal",
|
| 679 |
+
"model_t_mode": "post",
|
| 680 |
+
"final_from": "state",
|
| 681 |
+
"n_gen": 64,
|
| 682 |
+
"n_refs": 8,
|
| 683 |
+
"token_acc_mean": 0.03033447265625,
|
| 684 |
+
"token_acc_min": 0.015625,
|
| 685 |
+
"token_acc_max": 0.046875,
|
| 686 |
+
"exact_acc": 0.0,
|
| 687 |
+
"exact_count": 0,
|
| 688 |
+
"exact_ref_coverage": 0.0,
|
| 689 |
+
"exact_ref_count": 0,
|
| 690 |
+
"exact_ref_hits": [],
|
| 691 |
+
"best_ref_idx": [
|
| 692 |
+
1,
|
| 693 |
+
1,
|
| 694 |
+
1,
|
| 695 |
+
1,
|
| 696 |
+
7,
|
| 697 |
+
3,
|
| 698 |
+
1,
|
| 699 |
+
7,
|
| 700 |
+
0,
|
| 701 |
+
0,
|
| 702 |
+
0,
|
| 703 |
+
3,
|
| 704 |
+
1,
|
| 705 |
+
0,
|
| 706 |
+
1,
|
| 707 |
+
5,
|
| 708 |
+
0,
|
| 709 |
+
0,
|
| 710 |
+
3,
|
| 711 |
+
0,
|
| 712 |
+
0,
|
| 713 |
+
1,
|
| 714 |
+
0,
|
| 715 |
+
7,
|
| 716 |
+
7,
|
| 717 |
+
1,
|
| 718 |
+
7,
|
| 719 |
+
0,
|
| 720 |
+
1,
|
| 721 |
+
0,
|
| 722 |
+
7,
|
| 723 |
+
1,
|
| 724 |
+
0,
|
| 725 |
+
0,
|
| 726 |
+
0,
|
| 727 |
+
0,
|
| 728 |
+
0,
|
| 729 |
+
3,
|
| 730 |
+
1,
|
| 731 |
+
0,
|
| 732 |
+
0,
|
| 733 |
+
1,
|
| 734 |
+
7,
|
| 735 |
+
5,
|
| 736 |
+
1,
|
| 737 |
+
0,
|
| 738 |
+
1,
|
| 739 |
+
1,
|
| 740 |
+
1,
|
| 741 |
+
0,
|
| 742 |
+
0,
|
| 743 |
+
0,
|
| 744 |
+
0,
|
| 745 |
+
0,
|
| 746 |
+
3,
|
| 747 |
+
0,
|
| 748 |
+
1,
|
| 749 |
+
7,
|
| 750 |
+
7,
|
| 751 |
+
0,
|
| 752 |
+
7,
|
| 753 |
+
0,
|
| 754 |
+
7,
|
| 755 |
+
5
|
| 756 |
+
],
|
| 757 |
+
"best_token_acc": [
|
| 758 |
+
0.03125,
|
| 759 |
+
0.02734375,
|
| 760 |
+
0.02734375,
|
| 761 |
+
0.03515625,
|
| 762 |
+
0.046875,
|
| 763 |
+
0.02734375,
|
| 764 |
+
0.03125,
|
| 765 |
+
0.04296875,
|
| 766 |
+
0.04296875,
|
| 767 |
+
0.02734375,
|
| 768 |
+
0.046875,
|
| 769 |
+
0.03515625,
|
| 770 |
+
0.02734375,
|
| 771 |
+
0.0234375,
|
| 772 |
+
0.01953125,
|
| 773 |
+
0.02734375,
|
| 774 |
+
0.02734375,
|
| 775 |
+
0.0390625,
|
| 776 |
+
0.02734375,
|
| 777 |
+
0.01953125,
|
| 778 |
+
0.03125,
|
| 779 |
+
0.03125,
|
| 780 |
+
0.01953125,
|
| 781 |
+
0.0390625,
|
| 782 |
+
0.0234375,
|
| 783 |
+
0.03125,
|
| 784 |
+
0.02734375,
|
| 785 |
+
0.02734375,
|
| 786 |
+
0.03125,
|
| 787 |
+
0.03125,
|
| 788 |
+
0.03125,
|
| 789 |
+
0.02734375,
|
| 790 |
+
0.03125,
|
| 791 |
+
0.03515625,
|
| 792 |
+
0.03125,
|
| 793 |
+
0.02734375,
|
| 794 |
+
0.03515625,
|
| 795 |
+
0.02734375,
|
| 796 |
+
0.0234375,
|
| 797 |
+
0.02734375,
|
| 798 |
+
0.03125,
|
| 799 |
+
0.03125,
|
| 800 |
+
0.03515625,
|
| 801 |
+
0.03515625,
|
| 802 |
+
0.02734375,
|
| 803 |
+
0.01953125,
|
| 804 |
+
0.0234375,
|
| 805 |
+
0.0234375,
|
| 806 |
+
0.015625,
|
| 807 |
+
0.046875,
|
| 808 |
+
0.03125,
|
| 809 |
+
0.02734375,
|
| 810 |
+
0.03515625,
|
| 811 |
+
0.0234375,
|
| 812 |
+
0.03125,
|
| 813 |
+
0.02734375,
|
| 814 |
+
0.0234375,
|
| 815 |
+
0.02734375,
|
| 816 |
+
0.03125,
|
| 817 |
+
0.03515625,
|
| 818 |
+
0.03515625,
|
| 819 |
+
0.03125,
|
| 820 |
+
0.03125,
|
| 821 |
+
0.0390625
|
| 822 |
+
]
|
| 823 |
+
}
|
| 824 |
+
},
|
| 825 |
+
"first_exact_by_run": {}
|
| 826 |
+
}
|
| 827 |
+
RESULT config=logistic_unigram_shared_highC ckpt_step=2000 views=1024000 token_acc=0.0303 exact=0/64 exact_refs=0 hits=[]
|
| 828 |
+
[combo-pilot] continue config=logistic_unigram_shared_highC step=2000
|
| 829 |
+
[combo-pilot] train config=logistic_unigram_shared_highC_seqrand from=1000 to=2000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.5
|
| 830 |
+
[combo-pilot] eval config=logistic_unigram_shared_highC_seqrand step=2000
|
| 831 |
+
[eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456 step=2000 soft=none
|
| 832 |
+
[decode] max_len=256 generated=64/64
|
| 833 |
+
{
|
| 834 |
+
"num_rows": 1,
|
| 835 |
+
"best_by_run": {
|
| 836 |
+
"train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456::none": {
|
| 837 |
+
"run": "train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456",
|
| 838 |
+
"checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456/step_0002000.pt",
|
| 839 |
+
"ckpt_step": 2000,
|
| 840 |
+
"endpoint_softening": "none",
|
| 841 |
+
"decode_rule": "flowmap",
|
| 842 |
+
"steps": 128,
|
| 843 |
+
"time_schedule": "logit_normal",
|
| 844 |
+
"model_t_mode": "post",
|
| 845 |
+
"final_from": "state",
|
| 846 |
+
"n_gen": 64,
|
| 847 |
+
"n_refs": 8,
|
| 848 |
+
"token_acc_mean": 0.04046630859375,
|
| 849 |
+
"token_acc_min": 0.01953125,
|
| 850 |
+
"token_acc_max": 0.06640625,
|
| 851 |
+
"exact_acc": 0.0,
|
| 852 |
+
"exact_count": 0,
|
| 853 |
+
"exact_ref_coverage": 0.0,
|
| 854 |
+
"exact_ref_count": 0,
|
| 855 |
+
"exact_ref_hits": [],
|
| 856 |
+
"best_ref_idx": [
|
| 857 |
+
0,
|
| 858 |
+
7,
|
| 859 |
+
0,
|
| 860 |
+
0,
|
| 861 |
+
0,
|
| 862 |
+
0,
|
| 863 |
+
7,
|
| 864 |
+
0,
|
| 865 |
+
7,
|
| 866 |
+
7,
|
| 867 |
+
7,
|
| 868 |
+
0,
|
| 869 |
+
7,
|
| 870 |
+
7,
|
| 871 |
+
7,
|
| 872 |
+
7,
|
| 873 |
+
7,
|
| 874 |
+
7,
|
| 875 |
+
7,
|
| 876 |
+
0,
|
| 877 |
+
0,
|
| 878 |
+
7,
|
| 879 |
+
0,
|
| 880 |
+
0,
|
| 881 |
+
0,
|
| 882 |
+
7,
|
| 883 |
+
0,
|
| 884 |
+
7,
|
| 885 |
+
0,
|
| 886 |
+
0,
|
| 887 |
+
0,
|
| 888 |
+
0,
|
| 889 |
+
7,
|
| 890 |
+
1,
|
| 891 |
+
0,
|
| 892 |
+
7,
|
| 893 |
+
0,
|
| 894 |
+
0,
|
| 895 |
+
5,
|
| 896 |
+
0,
|
| 897 |
+
0,
|
| 898 |
+
7,
|
| 899 |
+
0,
|
| 900 |
+
0,
|
| 901 |
+
0,
|
| 902 |
+
7,
|
| 903 |
+
5,
|
| 904 |
+
0,
|
| 905 |
+
5,
|
| 906 |
+
2,
|
| 907 |
+
0,
|
| 908 |
+
0,
|
| 909 |
+
0,
|
| 910 |
+
7,
|
| 911 |
+
0,
|
| 912 |
+
7,
|
| 913 |
+
1,
|
| 914 |
+
0,
|
| 915 |
+
0,
|
| 916 |
+
0,
|
| 917 |
+
7,
|
| 918 |
+
2,
|
| 919 |
+
0,
|
| 920 |
+
0
|
| 921 |
+
],
|
| 922 |
+
"best_token_acc": [
|
| 923 |
+
0.01953125,
|
| 924 |
+
0.03125,
|
| 925 |
+
0.03515625,
|
| 926 |
+
0.0546875,
|
| 927 |
+
0.0390625,
|
| 928 |
+
0.0546875,
|
| 929 |
+
0.0234375,
|
| 930 |
+
0.03125,
|
| 931 |
+
0.046875,
|
| 932 |
+
0.05078125,
|
| 933 |
+
0.0390625,
|
| 934 |
+
0.0234375,
|
| 935 |
+
0.0390625,
|
| 936 |
+
0.05859375,
|
| 937 |
+
0.02734375,
|
| 938 |
+
0.02734375,
|
| 939 |
+
0.0546875,
|
| 940 |
+
0.05078125,
|
| 941 |
+
0.03515625,
|
| 942 |
+
0.046875,
|
| 943 |
+
0.05859375,
|
| 944 |
+
0.02734375,
|
| 945 |
+
0.046875,
|
| 946 |
+
0.04296875,
|
| 947 |
+
0.0546875,
|
| 948 |
+
0.01953125,
|
| 949 |
+
0.046875,
|
| 950 |
+
0.03125,
|
| 951 |
+
0.05078125,
|
| 952 |
+
0.05859375,
|
| 953 |
+
0.04296875,
|
| 954 |
+
0.01953125,
|
| 955 |
+
0.05078125,
|
| 956 |
+
0.02734375,
|
| 957 |
+
0.046875,
|
| 958 |
+
0.03515625,
|
| 959 |
+
0.03515625,
|
| 960 |
+
0.05859375,
|
| 961 |
+
0.03125,
|
| 962 |
+
0.04296875,
|
| 963 |
+
0.046875,
|
| 964 |
+
0.05078125,
|
| 965 |
+
0.04296875,
|
| 966 |
+
0.0546875,
|
| 967 |
+
0.02734375,
|
| 968 |
+
0.02734375,
|
| 969 |
+
0.02734375,
|
| 970 |
+
0.046875,
|
| 971 |
+
0.01953125,
|
| 972 |
+
0.03515625,
|
| 973 |
+
0.06640625,
|
| 974 |
+
0.03515625,
|
| 975 |
+
0.046875,
|
| 976 |
+
0.046875,
|
| 977 |
+
0.05078125,
|
| 978 |
+
0.03125,
|
| 979 |
+
0.03125,
|
| 980 |
+
0.03125,
|
| 981 |
+
0.03125,
|
| 982 |
+
0.0546875,
|
| 983 |
+
0.0546875,
|
| 984 |
+
0.02734375,
|
| 985 |
+
0.0546875,
|
| 986 |
+
0.03125
|
| 987 |
+
]
|
| 988 |
+
}
|
| 989 |
+
},
|
| 990 |
+
"first_exact_by_run": {}
|
| 991 |
+
}
|
| 992 |
+
RESULT config=logistic_unigram_shared_highC_seqrand ckpt_step=2000 views=1024000 token_acc=0.0405 exact=0/64 exact_refs=0 hits=[]
|
| 993 |
+
[combo-pilot] continue config=logistic_unigram_shared_highC_seqrand step=2000
|
| 994 |
+
[combo-pilot] train config=logistic_unigram_shared_C1024 from=1000 to=2000 sampler=logistic_normal_linear_mean C=1.0->1024 unigram_shared=0.5 seq=0.0
|
LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/driver.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LTA_openwebtext_dualt/scripts/apple_to_apple_lta_checks.py
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import csv
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
import sys
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Iterable
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 19 |
+
if str(REPO_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 21 |
+
|
| 22 |
+
from eval import build_model_from_ckpt
|
| 23 |
+
from flowtext_lab.bridges import make_dirichlet_bridge_batch
|
| 24 |
+
from flowtext_lab.data import EosPadCollator, WrappedStreamingTextSequenceDataset, iter_text_records
|
| 25 |
+
from flowtext_lab.decode import sample_noise_simplex, state_for_model
|
| 26 |
+
from flowtext_lab.tokenization import BpeTextTokenizer
|
| 27 |
+
from train import TokenizedTextCollator, load_tokenized_hf_dataset
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def token_piece(tok: BpeTextTokenizer, idx: int) -> str:
|
| 31 |
+
raw = getattr(tok, "tokenizer", None)
|
| 32 |
+
id_to_token = getattr(raw, "id_to_token", None)
|
| 33 |
+
if callable(id_to_token):
|
| 34 |
+
piece = id_to_token(int(idx))
|
| 35 |
+
if piece is not None:
|
| 36 |
+
return str(piece)
|
| 37 |
+
return tok.decode([int(idx)], stop_at_eos=False, skip_special_tokens=False)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def token_text(tok: BpeTextTokenizer, idx: int) -> str:
|
| 41 |
+
return tok.decode([int(idx)], stop_at_eos=False, skip_special_tokens=False)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def compact_piece(s: str) -> str:
|
| 45 |
+
return s.replace("\n", "\\n").replace("\t", "\\t")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_batch(
|
| 49 |
+
*,
|
| 50 |
+
data_path: str,
|
| 51 |
+
tokenizer: BpeTextTokenizer,
|
| 52 |
+
max_len: int,
|
| 53 |
+
batch_size: int,
|
| 54 |
+
mode: str,
|
| 55 |
+
text_column: str | None,
|
| 56 |
+
openwebtext_split: str,
|
| 57 |
+
wrap_mode: str,
|
| 58 |
+
max_records: int,
|
| 59 |
+
tokenized_pad_token: str,
|
| 60 |
+
) -> dict[str, torch.Tensor]:
|
| 61 |
+
if mode == "tokenized_hf":
|
| 62 |
+
ds = load_tokenized_hf_dataset(data_path, max_records=max_records)
|
| 63 |
+
pad_id = tokenizer.pad_id if tokenized_pad_token == "pad" and tokenizer.pad_id is not None else tokenizer.eos_id
|
| 64 |
+
collate = TokenizedTextCollator(pad_id, max_len=max_len)
|
| 65 |
+
examples = [ds[i] for i in range(min(batch_size, len(ds)))]
|
| 66 |
+
return collate(examples)
|
| 67 |
+
if mode != "wrap":
|
| 68 |
+
raise ValueError(f"unknown data mode: {mode}")
|
| 69 |
+
ds = WrappedStreamingTextSequenceDataset(
|
| 70 |
+
data_path,
|
| 71 |
+
tokenizer,
|
| 72 |
+
max_len=max_len,
|
| 73 |
+
text_column=text_column,
|
| 74 |
+
openwebtext_split=openwebtext_split,
|
| 75 |
+
max_records_per_epoch=max_records,
|
| 76 |
+
wrap_mode=wrap_mode,
|
| 77 |
+
)
|
| 78 |
+
loader = DataLoader(ds, batch_size=batch_size, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=max_len))
|
| 79 |
+
return next(iter(loader))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def iter_record_lengths(
|
| 83 |
+
*,
|
| 84 |
+
data_path: str,
|
| 85 |
+
tokenizer: BpeTextTokenizer,
|
| 86 |
+
mode: str,
|
| 87 |
+
text_column: str | None,
|
| 88 |
+
openwebtext_split: str,
|
| 89 |
+
max_records: int,
|
| 90 |
+
) -> Iterable[int]:
|
| 91 |
+
if mode == "tokenized_hf":
|
| 92 |
+
ds = load_tokenized_hf_dataset(data_path, max_records=max_records)
|
| 93 |
+
for ex in ds:
|
| 94 |
+
raw = ex["input_ids"]
|
| 95 |
+
if hasattr(raw, "tolist"):
|
| 96 |
+
raw = raw.tolist()
|
| 97 |
+
yield len(raw)
|
| 98 |
+
return
|
| 99 |
+
for i, text in enumerate(
|
| 100 |
+
iter_text_records(
|
| 101 |
+
data_path,
|
| 102 |
+
text_column=text_column,
|
| 103 |
+
openwebtext_split=openwebtext_split,
|
| 104 |
+
detokenizer="auto",
|
| 105 |
+
)
|
| 106 |
+
):
|
| 107 |
+
if i >= max_records:
|
| 108 |
+
break
|
| 109 |
+
ids = tokenizer.encode(text, add_eos=False, add_special_tokens=False)
|
| 110 |
+
yield len(ids)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def rate_summary(values: list[float]) -> dict[str, float]:
|
| 114 |
+
if not values:
|
| 115 |
+
return {"mean": 0.0, "min": 0.0, "p50": 0.0, "p90": 0.0, "p99": 0.0, "max": 0.0}
|
| 116 |
+
vals = sorted(float(x) for x in values)
|
| 117 |
+
n = len(vals)
|
| 118 |
+
|
| 119 |
+
def q(p: float) -> float:
|
| 120 |
+
return vals[min(n - 1, max(0, int(round((n - 1) * p))))]
|
| 121 |
+
|
| 122 |
+
return {
|
| 123 |
+
"mean": float(sum(vals) / n),
|
| 124 |
+
"min": float(vals[0]),
|
| 125 |
+
"p50": float(q(0.5)),
|
| 126 |
+
"p90": float(q(0.9)),
|
| 127 |
+
"p99": float(q(0.99)),
|
| 128 |
+
"max": float(vals[-1]),
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def distribution_entropy_from_counts(counts: Counter[int]) -> float:
|
| 133 |
+
total = sum(counts.values())
|
| 134 |
+
if total <= 0:
|
| 135 |
+
return 0.0
|
| 136 |
+
out = 0.0
|
| 137 |
+
for c in counts.values():
|
| 138 |
+
p = c / total
|
| 139 |
+
out -= p * math.log(max(p, 1e-12))
|
| 140 |
+
return float(out)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def token_feature_rates(ids: torch.Tensor, tok: BpeTextTokenizer) -> dict[str, float]:
|
| 144 |
+
flat = [int(x) for x in ids.reshape(-1).tolist()]
|
| 145 |
+
if not flat:
|
| 146 |
+
return {}
|
| 147 |
+
pieces = [token_piece(tok, x) for x in flat]
|
| 148 |
+
texts = [token_text(tok, x) for x in flat]
|
| 149 |
+
specials = {tok.eos_id, tok.bos_id, tok.unk_id}
|
| 150 |
+
if tok.pad_id is not None:
|
| 151 |
+
specials.add(tok.pad_id)
|
| 152 |
+
denom = len(flat)
|
| 153 |
+
normal = [i for i, x in enumerate(flat) if x not in specials]
|
| 154 |
+
normal_denom = max(len(normal), 1)
|
| 155 |
+
return {
|
| 156 |
+
"bert_hash_rate": sum(pieces[i].startswith("##") for i in normal) / normal_denom,
|
| 157 |
+
"spm_cont_rate": sum((not pieces[i].startswith("▁")) and (not pieces[i].startswith("<")) for i in normal) / normal_denom,
|
| 158 |
+
"single_char_rate": sum(len(texts[i].strip()) == 1 for i in normal) / normal_denom,
|
| 159 |
+
"digit_piece_rate": sum(any(ch.isdigit() for ch in pieces[i]) for i in normal) / normal_denom,
|
| 160 |
+
"url_piece_rate": sum(("http" in pieces[i].lower() or "www" in pieces[i].lower() or ".com" in pieces[i].lower()) for i in normal) / normal_denom,
|
| 161 |
+
"special_rate": sum(x in specials for x in flat) / denom,
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def command_data(args: argparse.Namespace) -> None:
|
| 166 |
+
tok = BpeTextTokenizer.from_file(args.tokenizer_path)
|
| 167 |
+
batch = load_batch(
|
| 168 |
+
data_path=args.data_path,
|
| 169 |
+
tokenizer=tok,
|
| 170 |
+
max_len=args.max_len,
|
| 171 |
+
batch_size=args.n_sequences,
|
| 172 |
+
mode=args.data_mode,
|
| 173 |
+
text_column=args.text_column,
|
| 174 |
+
openwebtext_split=args.openwebtext_split,
|
| 175 |
+
wrap_mode=args.wrap_mode,
|
| 176 |
+
max_records=args.max_records,
|
| 177 |
+
tokenized_pad_token=args.tokenized_pad_token,
|
| 178 |
+
)
|
| 179 |
+
ids = batch["ids"]
|
| 180 |
+
attn = batch.get("attn_mask", torch.ones_like(ids, dtype=torch.bool))
|
| 181 |
+
valid_ids = ids[attn]
|
| 182 |
+
counts = Counter(int(x) for x in valid_ids.tolist())
|
| 183 |
+
top = [
|
| 184 |
+
{
|
| 185 |
+
"id": int(i),
|
| 186 |
+
"piece": compact_piece(token_piece(tok, int(i))),
|
| 187 |
+
"text": compact_piece(token_text(tok, int(i))),
|
| 188 |
+
"count": int(c),
|
| 189 |
+
"rate": float(c / max(valid_ids.numel(), 1)),
|
| 190 |
+
}
|
| 191 |
+
for i, c in counts.most_common(args.top_k)
|
| 192 |
+
]
|
| 193 |
+
seq_lens = attn.long().sum(dim=1).tolist()
|
| 194 |
+
internal = ids[:, 1:-1] if ids.size(1) > 2 else ids[:, :0]
|
| 195 |
+
internal_attn = attn[:, 1:-1] if attn.size(1) > 2 else attn[:, :0]
|
| 196 |
+
eos_internal = ((internal == tok.eos_id) & internal_attn).long().sum(dim=1).tolist()
|
| 197 |
+
pad_internal = []
|
| 198 |
+
if tok.pad_id is not None:
|
| 199 |
+
pad_internal = ((internal == tok.pad_id) & internal_attn).long().sum(dim=1).tolist()
|
| 200 |
+
pos0 = Counter(int(x) for x in ids[:, 0].tolist())
|
| 201 |
+
last_valid = []
|
| 202 |
+
for row, mask in zip(ids, attn):
|
| 203 |
+
idx = int(mask.long().sum().item()) - 1
|
| 204 |
+
if idx >= 0:
|
| 205 |
+
last_valid.append(int(row[idx].item()))
|
| 206 |
+
last_counts = Counter(last_valid)
|
| 207 |
+
record_lengths = list(
|
| 208 |
+
iter_record_lengths(
|
| 209 |
+
data_path=args.data_path,
|
| 210 |
+
tokenizer=tok,
|
| 211 |
+
mode=args.data_mode,
|
| 212 |
+
text_column=args.text_column,
|
| 213 |
+
openwebtext_split=args.openwebtext_split,
|
| 214 |
+
max_records=args.max_records,
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
out = {
|
| 218 |
+
"name": args.name,
|
| 219 |
+
"data_path": args.data_path,
|
| 220 |
+
"data_mode": args.data_mode,
|
| 221 |
+
"tokenizer_path": args.tokenizer_path,
|
| 222 |
+
"vocab_size": tok.vocab_size,
|
| 223 |
+
"bos_id": tok.bos_id,
|
| 224 |
+
"bos_piece": token_piece(tok, tok.bos_id),
|
| 225 |
+
"eos_id": tok.eos_id,
|
| 226 |
+
"eos_piece": token_piece(tok, tok.eos_id),
|
| 227 |
+
"pad_id": tok.pad_id,
|
| 228 |
+
"n_sequences": int(ids.size(0)),
|
| 229 |
+
"max_len": args.max_len,
|
| 230 |
+
"sequence_len": rate_summary([float(x) for x in seq_lens]),
|
| 231 |
+
"record_token_len_no_special_no_eos": rate_summary([float(x) for x in record_lengths]),
|
| 232 |
+
"internal_eos_per_seq": rate_summary([float(x) for x in eos_internal]),
|
| 233 |
+
"internal_pad_per_seq": rate_summary([float(x) for x in pad_internal]) if pad_internal else None,
|
| 234 |
+
"pos0_top": [
|
| 235 |
+
{"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(ids.size(0), 1)}
|
| 236 |
+
for i, c in pos0.most_common(args.top_k)
|
| 237 |
+
],
|
| 238 |
+
"last_valid_top": [
|
| 239 |
+
{"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(len(last_valid), 1)}
|
| 240 |
+
for i, c in last_counts.most_common(args.top_k)
|
| 241 |
+
],
|
| 242 |
+
"unigram_entropy": distribution_entropy_from_counts(counts),
|
| 243 |
+
"token_feature_rates": token_feature_rates(valid_ids, tok),
|
| 244 |
+
"top_unigram": top,
|
| 245 |
+
}
|
| 246 |
+
Path(args.out_json).parent.mkdir(parents=True, exist_ok=True)
|
| 247 |
+
Path(args.out_json).write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
|
| 248 |
+
print(json.dumps(out, indent=2, ensure_ascii=False), flush=True)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def ckpt_arg(ckpt_args: dict[str, Any], key: str, default: Any) -> Any:
|
| 252 |
+
return ckpt_args.get(key, default)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def make_bridge_for_eval(
|
| 256 |
+
*,
|
| 257 |
+
ids: torch.Tensor,
|
| 258 |
+
attn: torch.Tensor,
|
| 259 |
+
ckpt_args: dict[str, Any],
|
| 260 |
+
vocab_size: int,
|
| 261 |
+
t_value: float,
|
| 262 |
+
force_mask_ratio: float | None,
|
| 263 |
+
eps: float,
|
| 264 |
+
) -> Any:
|
| 265 |
+
return make_dirichlet_bridge_batch(
|
| 266 |
+
ids=ids,
|
| 267 |
+
attn_mask=attn,
|
| 268 |
+
vocab_size=vocab_size,
|
| 269 |
+
target_prob=float(ckpt_arg(ckpt_args, "target_prob", 1.0)),
|
| 270 |
+
min_t=float(ckpt_arg(ckpt_args, "min_t", 0.0)),
|
| 271 |
+
max_t=float(ckpt_arg(ckpt_args, "max_t", 1.0)),
|
| 272 |
+
min_mask_ratio=float(ckpt_arg(ckpt_args, "min_mask_ratio", 0.1)),
|
| 273 |
+
max_mask_ratio=float(ckpt_arg(ckpt_args, "max_mask_ratio", 1.0)),
|
| 274 |
+
wrong_token_replace_prob=ckpt_arg(ckpt_args, "wrong_token_replace_prob", "0.0"),
|
| 275 |
+
wrong_token_schedule=str(ckpt_arg(ckpt_args, "wrong_token_schedule", "constant")),
|
| 276 |
+
wrong_token_exp_k=float(ckpt_arg(ckpt_args, "wrong_token_exp_k", 1.0)),
|
| 277 |
+
dirichlet_concentration_min=float(ckpt_arg(ckpt_args, "dirichlet_concentration_min", 1.0)),
|
| 278 |
+
dirichlet_concentration_max=float(ckpt_arg(ckpt_args, "dirichlet_concentration_max", 1024.0)),
|
| 279 |
+
eps=eps,
|
| 280 |
+
state_format=str(ckpt_arg(ckpt_args, "state_format", ckpt_arg(ckpt_args, "input_format", "prob"))),
|
| 281 |
+
dirichlet_endpoint_mode=str(ckpt_arg(ckpt_args, "dirichlet_endpoint_mode", "bernoulli_wrong")),
|
| 282 |
+
dirichlet_semantic_t_mode=str(ckpt_arg(ckpt_args, "dirichlet_semantic_t_mode", "same")),
|
| 283 |
+
dirichlet_semantic_t_value=float(ckpt_arg(ckpt_args, "dirichlet_semantic_t_value", 0.0)),
|
| 284 |
+
dirichlet_semantic_t_curve=str(ckpt_arg(ckpt_args, "dirichlet_semantic_t_curve", "linear")),
|
| 285 |
+
dirichlet_semantic_t_power=float(ckpt_arg(ckpt_args, "dirichlet_semantic_t_power", 1.0)),
|
| 286 |
+
dirichlet_support_t_curve=str(ckpt_arg(ckpt_args, "dirichlet_support_t_curve", "linear")),
|
| 287 |
+
dirichlet_support_t_power=float(ckpt_arg(ckpt_args, "dirichlet_support_t_power", 1.0)),
|
| 288 |
+
endpoint_sequence_random_prob_alpha=float(ckpt_arg(ckpt_args, "endpoint_sequence_random_prob_alpha", 0.0)),
|
| 289 |
+
categorical_wrong_from_full_vocab=bool(ckpt_arg(ckpt_args, "categorical_wrong_from_full_vocab", False)),
|
| 290 |
+
categorical_wrong_from_batch_valid_tokens=bool(ckpt_arg(ckpt_args, "categorical_wrong_from_batch_valid_tokens", False)),
|
| 291 |
+
categorical_wrong_basin_token_ids=ckpt_arg(ckpt_args, "categorical_wrong_basin_token_ids", ""),
|
| 292 |
+
categorical_wrong_basin_prob=float(ckpt_arg(ckpt_args, "categorical_wrong_basin_prob", 0.0)),
|
| 293 |
+
categorical_wrong_unigram_prob=float(ckpt_arg(ckpt_args, "categorical_wrong_unigram_prob", 0.0)),
|
| 294 |
+
categorical_wrong_uniform_prob=float(ckpt_arg(ckpt_args, "categorical_wrong_uniform_prob", 0.0)),
|
| 295 |
+
categorical_wrong_prob_floor=float(ckpt_arg(ckpt_args, "categorical_wrong_prob_floor", 0.0)),
|
| 296 |
+
categorical_gold_prob_floor=float(ckpt_arg(ckpt_args, "categorical_gold_prob_floor", 0.0)),
|
| 297 |
+
categorical_gold_prob_ceil=float(ckpt_arg(ckpt_args, "categorical_gold_prob_ceil", 1.0)),
|
| 298 |
+
simplex_bridge_sampler=str(ckpt_arg(ckpt_args, "simplex_bridge_sampler", "dirichlet")),
|
| 299 |
+
logistic_normal_sigma_min=float(ckpt_arg(ckpt_args, "logistic_normal_sigma_min", 0.18)),
|
| 300 |
+
logistic_normal_sigma_max=float(ckpt_arg(ckpt_args, "logistic_normal_sigma_max", 2.2)),
|
| 301 |
+
logistic_normal_tau_min=float(ckpt_arg(ckpt_args, "logistic_normal_tau_min", 0.65)),
|
| 302 |
+
logistic_normal_tau_max=float(ckpt_arg(ckpt_args, "logistic_normal_tau_max", 1.15)),
|
| 303 |
+
force_t=t_value,
|
| 304 |
+
force_mask_ratio=force_mask_ratio,
|
| 305 |
+
mask_ratio_floor_schedule=str(ckpt_arg(ckpt_args, "mask_ratio_floor_schedule", "none")),
|
| 306 |
+
mask_mixture_original_prob=float(ckpt_arg(ckpt_args, "mask_mixture_original_prob", 0.0)),
|
| 307 |
+
mask_mixture_lowk_prob=float(ckpt_arg(ckpt_args, "mask_mixture_lowk_prob", 0.0)),
|
| 308 |
+
mask_mixture_lowcorrupt_prob=float(ckpt_arg(ckpt_args, "mask_mixture_lowcorrupt_prob", 0.0)),
|
| 309 |
+
mask_mixture_block_prob=float(ckpt_arg(ckpt_args, "mask_mixture_block_prob", 0.0)),
|
| 310 |
+
mask_mixture_all_prob=float(ckpt_arg(ckpt_args, "mask_mixture_all_prob", 0.0)),
|
| 311 |
+
mask_mixture_lowk_clean_tokens=ckpt_arg(ckpt_args, "mask_mixture_lowk_clean_tokens", "1,2,4,8,16,32,64"),
|
| 312 |
+
mask_mixture_lowcorrupt_tokens=ckpt_arg(ckpt_args, "mask_mixture_lowcorrupt_tokens", "1,2,4,8,16,32,64"),
|
| 313 |
+
mask_mixture_block_tokens=ckpt_arg(ckpt_args, "mask_mixture_block_tokens", "64,128"),
|
| 314 |
+
clean_state_mode=str(ckpt_arg(ckpt_args, "clean_state_mode", "onehot")),
|
| 315 |
+
return_dense_targets=False,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def masked_loss_acc(logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> dict[str, float]:
|
| 320 |
+
flat_mask = mask.reshape(-1)
|
| 321 |
+
if not bool(flat_mask.any().item()):
|
| 322 |
+
return {"nll": 0.0, "ppl": 1.0, "acc": 0.0, "tokens": 0}
|
| 323 |
+
flat_logits = logits.reshape(-1, logits.size(-1))[flat_mask]
|
| 324 |
+
flat_target = target.reshape(-1)[flat_mask]
|
| 325 |
+
loss = F.cross_entropy(flat_logits, flat_target, reduction="mean")
|
| 326 |
+
pred = flat_logits.argmax(dim=-1)
|
| 327 |
+
acc = (pred == flat_target).float().mean()
|
| 328 |
+
return {
|
| 329 |
+
"nll": float(loss.detach().cpu()),
|
| 330 |
+
"ppl": float(torch.exp(loss.clamp(max=50)).detach().cpu()),
|
| 331 |
+
"acc": float(acc.detach().cpu()),
|
| 332 |
+
"tokens": int(flat_mask.sum().detach().cpu()),
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@torch.inference_mode()
|
| 337 |
+
def command_teacher(args: argparse.Namespace) -> None:
|
| 338 |
+
tok = BpeTextTokenizer.from_file(args.tokenizer_path)
|
| 339 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
|
| 340 |
+
ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
| 341 |
+
ckpt_args = dict(ckpt.get("args", {}))
|
| 342 |
+
model = build_model_from_ckpt(ckpt, tok.vocab_size, args.max_len, device).eval()
|
| 343 |
+
batch = load_batch(
|
| 344 |
+
data_path=args.data_path,
|
| 345 |
+
tokenizer=tok,
|
| 346 |
+
max_len=args.max_len,
|
| 347 |
+
batch_size=args.batch_size,
|
| 348 |
+
mode=args.data_mode,
|
| 349 |
+
text_column=args.text_column,
|
| 350 |
+
openwebtext_split=args.openwebtext_split,
|
| 351 |
+
wrap_mode=args.wrap_mode,
|
| 352 |
+
max_records=args.max_records,
|
| 353 |
+
tokenized_pad_token=args.tokenized_pad_token,
|
| 354 |
+
)
|
| 355 |
+
ids = batch["ids"].to(device)
|
| 356 |
+
attn = batch.get("attn_mask", torch.ones_like(ids, dtype=torch.bool)).to(device)
|
| 357 |
+
rows = []
|
| 358 |
+
for t_value in [float(x) for x in args.t_values.split(",") if x.strip()]:
|
| 359 |
+
torch.manual_seed(args.seed + int(round(t_value * 1000000)))
|
| 360 |
+
bridge = make_bridge_for_eval(
|
| 361 |
+
ids=ids,
|
| 362 |
+
attn=attn,
|
| 363 |
+
ckpt_args=ckpt_args,
|
| 364 |
+
vocab_size=tok.vocab_size,
|
| 365 |
+
t_value=t_value,
|
| 366 |
+
force_mask_ratio=args.force_mask_ratio,
|
| 367 |
+
eps=args.eps,
|
| 368 |
+
)
|
| 369 |
+
model_t = bridge.t
|
| 370 |
+
logits = model(state_for_model(model, bridge.state, args.eps), model_t, attn).float()
|
| 371 |
+
valid = attn
|
| 372 |
+
corrupt = bridge.corrupt_mask & attn
|
| 373 |
+
pos0_pred = logits[:, 0].argmax(dim=-1)
|
| 374 |
+
last_pred = []
|
| 375 |
+
for b in range(ids.size(0)):
|
| 376 |
+
last = int(attn[b].long().sum().item()) - 1
|
| 377 |
+
last_pred.append(int(logits[b, last].argmax().detach().cpu()) if last >= 0 else -1)
|
| 378 |
+
pos0_counts = Counter(int(x) for x in pos0_pred.detach().cpu().tolist())
|
| 379 |
+
last_counts = Counter(last_pred)
|
| 380 |
+
probs = F.softmax(logits, dim=-1)
|
| 381 |
+
rows.append(
|
| 382 |
+
{
|
| 383 |
+
"name": args.name,
|
| 384 |
+
"checkpoint": args.checkpoint,
|
| 385 |
+
"ckpt_step": int(ckpt.get("step", -1)),
|
| 386 |
+
"t": t_value,
|
| 387 |
+
"force_mask_ratio": args.force_mask_ratio,
|
| 388 |
+
"corrupt_frac": float(corrupt.float().mean().detach().cpu()),
|
| 389 |
+
"wrong_frac": float((bridge.wrong_mask & attn).float().sum().detach().cpu() / attn.float().sum().clamp_min(1).detach().cpu()),
|
| 390 |
+
"valid": masked_loss_acc(logits, ids, valid),
|
| 391 |
+
"corrupt": masked_loss_acc(logits, ids, corrupt),
|
| 392 |
+
"dist_entropy": float((-(probs.clamp_min(args.eps) * probs.clamp_min(args.eps).log()).sum(dim=-1)[valid]).mean().detach().cpu()),
|
| 393 |
+
"mean_maxp": float(probs.max(dim=-1).values[valid].mean().detach().cpu()),
|
| 394 |
+
"pos0_gold_id": int(ids[0, 0].detach().cpu()),
|
| 395 |
+
"pos0_gold_piece": token_piece(tok, int(ids[0, 0].detach().cpu())),
|
| 396 |
+
"pos0_top": [
|
| 397 |
+
{"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(ids.size(0), 1)}
|
| 398 |
+
for i, c in pos0_counts.most_common(5)
|
| 399 |
+
],
|
| 400 |
+
"last_top": [
|
| 401 |
+
{"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(ids.size(0), 1)}
|
| 402 |
+
for i, c in last_counts.most_common(5)
|
| 403 |
+
],
|
| 404 |
+
}
|
| 405 |
+
)
|
| 406 |
+
out = Path(args.out_json)
|
| 407 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 408 |
+
out.write_text(json.dumps(rows, indent=2, ensure_ascii=False), encoding="utf-8")
|
| 409 |
+
with out.with_suffix(".tsv").open("w", newline="", encoding="utf-8") as f:
|
| 410 |
+
fields = [
|
| 411 |
+
"name",
|
| 412 |
+
"ckpt_step",
|
| 413 |
+
"t",
|
| 414 |
+
"force_mask_ratio",
|
| 415 |
+
"corrupt_frac",
|
| 416 |
+
"wrong_frac",
|
| 417 |
+
"valid_nll",
|
| 418 |
+
"valid_acc",
|
| 419 |
+
"corrupt_nll",
|
| 420 |
+
"corrupt_acc",
|
| 421 |
+
"dist_entropy",
|
| 422 |
+
"mean_maxp",
|
| 423 |
+
"pos0_gold_piece",
|
| 424 |
+
"pos0_top",
|
| 425 |
+
"last_top",
|
| 426 |
+
]
|
| 427 |
+
writer = csv.DictWriter(f, fieldnames=fields, delimiter="\t")
|
| 428 |
+
writer.writeheader()
|
| 429 |
+
for row in rows:
|
| 430 |
+
writer.writerow(
|
| 431 |
+
{
|
| 432 |
+
"name": row["name"],
|
| 433 |
+
"ckpt_step": row["ckpt_step"],
|
| 434 |
+
"t": row["t"],
|
| 435 |
+
"force_mask_ratio": row["force_mask_ratio"],
|
| 436 |
+
"corrupt_frac": row["corrupt_frac"],
|
| 437 |
+
"wrong_frac": row["wrong_frac"],
|
| 438 |
+
"valid_nll": row["valid"]["nll"],
|
| 439 |
+
"valid_acc": row["valid"]["acc"],
|
| 440 |
+
"corrupt_nll": row["corrupt"]["nll"],
|
| 441 |
+
"corrupt_acc": row["corrupt"]["acc"],
|
| 442 |
+
"dist_entropy": row["dist_entropy"],
|
| 443 |
+
"mean_maxp": row["mean_maxp"],
|
| 444 |
+
"pos0_gold_piece": row["pos0_gold_piece"],
|
| 445 |
+
"pos0_top": " | ".join(f"{x['piece']}:{x['rate']:.2f}" for x in row["pos0_top"]),
|
| 446 |
+
"last_top": " | ".join(f"{x['piece']}:{x['rate']:.2f}" for x in row["last_top"]),
|
| 447 |
+
}
|
| 448 |
+
)
|
| 449 |
+
for row in rows:
|
| 450 |
+
print(
|
| 451 |
+
f"{row['name']} step={row['ckpt_step']} t={row['t']:.4f} "
|
| 452 |
+
f"valid_nll={row['valid']['nll']:.3f} valid_acc={row['valid']['acc']:.3f} "
|
| 453 |
+
f"corrupt_nll={row['corrupt']['nll']:.3f} corrupt_acc={row['corrupt']['acc']:.3f} "
|
| 454 |
+
f"pos0={row['pos0_top'][0]['piece']}:{row['pos0_top'][0]['rate']:.2f}",
|
| 455 |
+
flush=True,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def filter_top_p(probs: torch.Tensor, top_p: float, eps: float) -> torch.Tensor:
|
| 460 |
+
if top_p >= 1.0:
|
| 461 |
+
return probs
|
| 462 |
+
sorted_vals, sorted_idx = torch.sort(probs, dim=-1, descending=True)
|
| 463 |
+
total = sorted_vals.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 464 |
+
remove = sorted_vals.cumsum(dim=-1) > top_p * total
|
| 465 |
+
remove[..., 0] = False
|
| 466 |
+
sorted_vals = sorted_vals.masked_fill(remove, 0.0)
|
| 467 |
+
out = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_vals)
|
| 468 |
+
return out / out.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def distribution_metrics(probs: torch.Tensor, ids: torch.Tensor, tok: BpeTextTokenizer, prefix: str) -> dict[str, Any]:
|
| 472 |
+
p = probs.clamp_min(1e-12)
|
| 473 |
+
ent = float((-(p * p.log()).sum(dim=-1)).mean().detach().cpu())
|
| 474 |
+
maxp, arg = probs.max(dim=-1)
|
| 475 |
+
counts = Counter(int(x) for x in arg.reshape(-1).detach().cpu().tolist())
|
| 476 |
+
return {
|
| 477 |
+
f"{prefix}_entropy": ent,
|
| 478 |
+
f"{prefix}_mean_top_mass": float(maxp.mean().detach().cpu()),
|
| 479 |
+
f"{prefix}_argmax_token_entropy": distribution_entropy_from_counts(counts),
|
| 480 |
+
f"{prefix}_argmax_top": [
|
| 481 |
+
{"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(arg.numel(), 1)}
|
| 482 |
+
for i, c in counts.most_common(8)
|
| 483 |
+
],
|
| 484 |
+
**{f"{prefix}_{k}": v for k, v in token_feature_rates(arg.detach().cpu(), tok).items()},
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
@torch.inference_mode()
|
| 489 |
+
def command_trace(args: argparse.Namespace) -> None:
|
| 490 |
+
tok = BpeTextTokenizer.from_file(args.tokenizer_path)
|
| 491 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
|
| 492 |
+
torch.manual_seed(args.seed)
|
| 493 |
+
ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
| 494 |
+
model = build_model_from_ckpt(ckpt, tok.vocab_size, args.max_len, device).eval()
|
| 495 |
+
eps = args.eps
|
| 496 |
+
bs = args.batch_size
|
| 497 |
+
probs = sample_noise_simplex(
|
| 498 |
+
(bs, args.max_len),
|
| 499 |
+
tok.vocab_size,
|
| 500 |
+
device,
|
| 501 |
+
eps,
|
| 502 |
+
noise_mode="dirichlet",
|
| 503 |
+
target_prob=1.0,
|
| 504 |
+
noise_sigma=-1.0,
|
| 505 |
+
dirichlet_concentration=args.concentration_min,
|
| 506 |
+
)
|
| 507 |
+
attn = torch.ones((bs, args.max_len), dtype=torch.bool, device=device)
|
| 508 |
+
log_cmin = math.log(args.concentration_min)
|
| 509 |
+
log_cmax = math.log(args.concentration_max)
|
| 510 |
+
out = Path(args.out_jsonl)
|
| 511 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 512 |
+
snapshot = set(int(x) for x in args.trace_steps.split(",") if x.strip())
|
| 513 |
+
last_endpoint = probs
|
| 514 |
+
with out.open("w", encoding="utf-8") as f:
|
| 515 |
+
for step in range(args.steps):
|
| 516 |
+
support_t = (step + 1) / max(args.steps, 1)
|
| 517 |
+
t = torch.full((bs,), support_t, dtype=torch.float32, device=device)
|
| 518 |
+
logits = model(state_for_model(model, probs, eps), t, attn).float()
|
| 519 |
+
endpoint = F.softmax(logits / args.endpoint_temp, dim=-1)
|
| 520 |
+
endpoint = filter_top_p(endpoint, args.endpoint_top_p, eps)
|
| 521 |
+
tau = args.gumbel_tau_start + support_t * (args.gumbel_tau_end - args.gumbel_tau_start)
|
| 522 |
+
uniform = torch.rand_like(endpoint).clamp_(eps, 1.0 - eps)
|
| 523 |
+
gumbel = -torch.log(-torch.log(uniform))
|
| 524 |
+
projected = F.softmax((endpoint.clamp_min(eps).log() + gumbel) / max(tau, eps), dim=-1)
|
| 525 |
+
last_endpoint = projected
|
| 526 |
+
mean = (1.0 - support_t) / tok.vocab_size + support_t * projected
|
| 527 |
+
mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 528 |
+
conc = math.exp(log_cmin + support_t * (log_cmax - log_cmin))
|
| 529 |
+
alpha = (mean * conc).clamp_min(eps)
|
| 530 |
+
probs = torch._standard_gamma(alpha).clamp_min(eps)
|
| 531 |
+
probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
|
| 532 |
+
step_num = step + 1
|
| 533 |
+
if step_num in snapshot or step_num == args.steps:
|
| 534 |
+
row = {
|
| 535 |
+
"name": args.name,
|
| 536 |
+
"ckpt_step": int(ckpt.get("step", -1)),
|
| 537 |
+
"step": step_num,
|
| 538 |
+
"support_t": support_t,
|
| 539 |
+
"tau": tau,
|
| 540 |
+
"concentration": conc,
|
| 541 |
+
}
|
| 542 |
+
row.update(distribution_metrics(endpoint, endpoint.argmax(dim=-1), tok, "a"))
|
| 543 |
+
row.update(distribution_metrics(projected, projected.argmax(dim=-1), tok, "e"))
|
| 544 |
+
row.update(distribution_metrics(probs, probs.argmax(dim=-1), tok, "p"))
|
| 545 |
+
for pos in [0, 1, args.max_len - 2, args.max_len - 1]:
|
| 546 |
+
a_id = int(endpoint[0, pos].argmax().detach().cpu())
|
| 547 |
+
e_id = int(projected[0, pos].argmax().detach().cpu())
|
| 548 |
+
p_id = int(probs[0, pos].argmax().detach().cpu())
|
| 549 |
+
row[f"pos{pos}_a"] = {"id": a_id, "piece": compact_piece(token_piece(tok, a_id)), "prob": float(endpoint[0, pos, a_id].detach().cpu())}
|
| 550 |
+
row[f"pos{pos}_e"] = {"id": e_id, "piece": compact_piece(token_piece(tok, e_id)), "prob": float(projected[0, pos, e_id].detach().cpu())}
|
| 551 |
+
row[f"pos{pos}_p"] = {"id": p_id, "piece": compact_piece(token_piece(tok, p_id)), "prob": float(probs[0, pos, p_id].detach().cpu())}
|
| 552 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 553 |
+
print(
|
| 554 |
+
f"{args.name} step={step_num} aH={row['a_entropy']:.2f} eH={row['e_entropy']:.2f} pH={row['p_entropy']:.2f} "
|
| 555 |
+
f"a_top={row['a_argmax_top'][0]['piece']}:{row['a_argmax_top'][0]['rate']:.2f} "
|
| 556 |
+
f"p_top={row['p_argmax_top'][0]['piece']}:{row['p_argmax_top'][0]['rate']:.2f}",
|
| 557 |
+
flush=True,
|
| 558 |
+
)
|
| 559 |
+
if args.final_out:
|
| 560 |
+
final_probs = 0.5 * probs + 0.5 * last_endpoint
|
| 561 |
+
ids = final_probs.argmax(dim=-1).detach().cpu().tolist()
|
| 562 |
+
Path(args.final_out).write_text("\n\n".join(tok.decode(row, stop_at_eos=False, skip_special_tokens=False) for row in ids), encoding="utf-8")
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def main() -> None:
|
| 566 |
+
ap = argparse.ArgumentParser()
|
| 567 |
+
sub = ap.add_subparsers(dest="cmd", required=True)
|
| 568 |
+
data = sub.add_parser("data")
|
| 569 |
+
data.add_argument("--name", required=True)
|
| 570 |
+
data.add_argument("--data_path", required=True)
|
| 571 |
+
data.add_argument("--tokenizer_path", required=True)
|
| 572 |
+
data.add_argument("--out_json", required=True)
|
| 573 |
+
data.add_argument("--data_mode", choices=["wrap", "tokenized_hf"], default="wrap")
|
| 574 |
+
data.add_argument("--text_column", default=None)
|
| 575 |
+
data.add_argument("--openwebtext_split", default="all")
|
| 576 |
+
data.add_argument("--wrap_mode", default="stream")
|
| 577 |
+
data.add_argument("--tokenized_pad_token", default="pad")
|
| 578 |
+
data.add_argument("--max_len", type=int, default=1024)
|
| 579 |
+
data.add_argument("--n_sequences", type=int, default=2048)
|
| 580 |
+
data.add_argument("--max_records", type=int, default=20000)
|
| 581 |
+
data.add_argument("--top_k", type=int, default=24)
|
| 582 |
+
data.set_defaults(func=command_data)
|
| 583 |
+
|
| 584 |
+
teacher = sub.add_parser("teacher")
|
| 585 |
+
teacher.add_argument("--name", required=True)
|
| 586 |
+
teacher.add_argument("--checkpoint", required=True)
|
| 587 |
+
teacher.add_argument("--data_path", required=True)
|
| 588 |
+
teacher.add_argument("--tokenizer_path", required=True)
|
| 589 |
+
teacher.add_argument("--out_json", required=True)
|
| 590 |
+
teacher.add_argument("--data_mode", choices=["wrap", "tokenized_hf"], default="wrap")
|
| 591 |
+
teacher.add_argument("--text_column", default=None)
|
| 592 |
+
teacher.add_argument("--openwebtext_split", default="all")
|
| 593 |
+
teacher.add_argument("--wrap_mode", default="stream")
|
| 594 |
+
teacher.add_argument("--tokenized_pad_token", default="pad")
|
| 595 |
+
teacher.add_argument("--max_len", type=int, default=1024)
|
| 596 |
+
teacher.add_argument("--batch_size", type=int, default=8)
|
| 597 |
+
teacher.add_argument("--max_records", type=int, default=20000)
|
| 598 |
+
teacher.add_argument("--t_values", default="0.0,0.0078125,0.03125,0.125,0.5,1.0")
|
| 599 |
+
teacher.add_argument("--force_mask_ratio", type=float, default=None)
|
| 600 |
+
teacher.add_argument("--seed", type=int, default=20260525)
|
| 601 |
+
teacher.add_argument("--eps", type=float, default=1e-8)
|
| 602 |
+
teacher.add_argument("--cpu", action="store_true")
|
| 603 |
+
teacher.set_defaults(func=command_teacher)
|
| 604 |
+
|
| 605 |
+
trace = sub.add_parser("trace")
|
| 606 |
+
trace.add_argument("--name", required=True)
|
| 607 |
+
trace.add_argument("--checkpoint", required=True)
|
| 608 |
+
trace.add_argument("--tokenizer_path", required=True)
|
| 609 |
+
trace.add_argument("--out_jsonl", required=True)
|
| 610 |
+
trace.add_argument("--final_out", default="")
|
| 611 |
+
trace.add_argument("--max_len", type=int, default=1024)
|
| 612 |
+
trace.add_argument("--batch_size", type=int, default=2)
|
| 613 |
+
trace.add_argument("--steps", type=int, default=128)
|
| 614 |
+
trace.add_argument("--trace_steps", default="1,2,4,8,16,32,64,96,128")
|
| 615 |
+
trace.add_argument("--concentration_min", type=float, default=30522)
|
| 616 |
+
trace.add_argument("--concentration_max", type=float, default=61044)
|
| 617 |
+
trace.add_argument("--endpoint_temp", type=float, default=1.45)
|
| 618 |
+
trace.add_argument("--endpoint_top_p", type=float, default=0.95)
|
| 619 |
+
trace.add_argument("--gumbel_tau_start", type=float, default=1.0)
|
| 620 |
+
trace.add_argument("--gumbel_tau_end", type=float, default=0.2)
|
| 621 |
+
trace.add_argument("--seed", type=int, default=20260525)
|
| 622 |
+
trace.add_argument("--eps", type=float, default=1e-8)
|
| 623 |
+
trace.add_argument("--cpu", action="store_true")
|
| 624 |
+
trace.set_defaults(func=command_trace)
|
| 625 |
+
|
| 626 |
+
args = ap.parse_args()
|
| 627 |
+
args.func(args)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
if __name__ == "__main__":
|
| 631 |
+
main()
|
LTA_openwebtext_dualt/scripts/build_lta_owt_compact_gpt2bpe_stream1024_train_minus_100k_np8.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export PACKING_MODE="${PACKING_MODE:-stream_chunks}"
|
| 7 |
+
export OUTPUT_SUFFIX="${OUTPUT_SUFFIX:-stream1024}"
|
| 8 |
+
export CACHE_SUFFIX="${CACHE_SUFFIX:-_stream1024}"
|
| 9 |
+
export LOG_DIR="${LOG_DIR:-logs/data_build_compact_gpt2bpe_stream1024}"
|
| 10 |
+
export VOCAB_SIZES="${VOCAB_SIZES:-2048,4096,8192}"
|
| 11 |
+
export NUM_PROC="${NUM_PROC:-8}"
|
| 12 |
+
|
| 13 |
+
exec bash scripts/build_lta_owt_compact_gpt2bpe_packed_train_minus_100k_np8.sh "$@"
|
LTA_openwebtext_dualt/scripts/build_owt_t5_elf_dataset.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Iterator
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_args() -> argparse.Namespace:
|
| 13 |
+
p = argparse.ArgumentParser(
|
| 14 |
+
description=(
|
| 15 |
+
"Build an ELF-style OpenWebText T5 token dataset. By default each raw "
|
| 16 |
+
"record is tokenized with add_special_tokens=False, overlength records "
|
| 17 |
+
"are split into max_len chunks, and short records stay short. The "
|
| 18 |
+
"packed_records mode instead concatenates EOS-terminated records up to "
|
| 19 |
+
"max_len while preserving record boundaries. stream_chunks concatenates "
|
| 20 |
+
"the token stream and slices exact max_len chunks, so chunk boundaries "
|
| 21 |
+
"are defined by the selected tokenizer."
|
| 22 |
+
)
|
| 23 |
+
)
|
| 24 |
+
p.add_argument("--data_path", required=True)
|
| 25 |
+
p.add_argument("--output_dir", required=True)
|
| 26 |
+
p.add_argument("--tokenizer_path", required=True)
|
| 27 |
+
p.add_argument("--text_column", default="text")
|
| 28 |
+
p.add_argument("--txt_record_mode", choices=["auto", "line", "eot"], default="auto")
|
| 29 |
+
p.add_argument("--openwebtext_split", choices=["all", "train_minus_100k", "valid_last_100k"], default="all")
|
| 30 |
+
p.add_argument("--openwebtext_valid_records", type=int, default=100_000)
|
| 31 |
+
p.add_argument("--detokenizer", default="auto")
|
| 32 |
+
p.add_argument("--max_len", type=int, default=1024)
|
| 33 |
+
p.add_argument(
|
| 34 |
+
"--packing_mode",
|
| 35 |
+
choices=["record_chunks", "packed_records", "stream_chunks"],
|
| 36 |
+
default="record_chunks",
|
| 37 |
+
help=(
|
| 38 |
+
"record_chunks preserves the old behavior. packed_records appends EOS "
|
| 39 |
+
"per record and packs multiple records into near-max_len examples. "
|
| 40 |
+
"stream_chunks appends EOS per record, concatenates records, and emits "
|
| 41 |
+
"exact max_len chunks across record boundaries."
|
| 42 |
+
),
|
| 43 |
+
)
|
| 44 |
+
p.add_argument("--max_records", type=int, default=0)
|
| 45 |
+
p.add_argument("--min_len", type=int, default=1)
|
| 46 |
+
p.add_argument("--add_eos", action="store_true", help="Append tokenizer EOS to each raw record before chunking.")
|
| 47 |
+
p.add_argument("--add_special_tokens", action="store_true", help="Let the tokenizer add model special tokens.")
|
| 48 |
+
p.add_argument("--cache_dir", default="")
|
| 49 |
+
p.add_argument("--max_shard_size", default="500MB")
|
| 50 |
+
p.add_argument("--num_proc", type=int, default=max(1, min(32, (os.cpu_count() or 8) // 2)))
|
| 51 |
+
p.add_argument("--tokenize_batch_size", type=int, default=1024)
|
| 52 |
+
p.add_argument(
|
| 53 |
+
"--merge_parts",
|
| 54 |
+
action="store_true",
|
| 55 |
+
help="After parallel part build, merge into one save_to_disk dataset. Slower but portable.",
|
| 56 |
+
)
|
| 57 |
+
p.add_argument("--keep_parts", action="store_true")
|
| 58 |
+
p.add_argument("--resume_parts", action="store_true", help="Keep completed part-* directories and build only missing parts.")
|
| 59 |
+
p.add_argument("--stats_only", action="store_true")
|
| 60 |
+
p.add_argument("--overwrite", action="store_true")
|
| 61 |
+
return p.parse_args()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _iter_examples(
|
| 65 |
+
*,
|
| 66 |
+
data_path: str,
|
| 67 |
+
tokenizer_path: str,
|
| 68 |
+
text_column: str | None,
|
| 69 |
+
txt_record_mode: str,
|
| 70 |
+
openwebtext_split: str,
|
| 71 |
+
openwebtext_valid_records: int,
|
| 72 |
+
detokenizer: str | None,
|
| 73 |
+
max_len: int,
|
| 74 |
+
packing_mode: str,
|
| 75 |
+
max_records: int,
|
| 76 |
+
min_len: int,
|
| 77 |
+
add_eos: bool,
|
| 78 |
+
add_special_tokens: bool,
|
| 79 |
+
) -> Iterator[dict]:
|
| 80 |
+
from flowtext_lab.data import iter_text_records
|
| 81 |
+
from flowtext_lab.tokenization import BpeTextTokenizer
|
| 82 |
+
|
| 83 |
+
tokenizer = BpeTextTokenizer.from_file(tokenizer_path)
|
| 84 |
+
seen_records = 0
|
| 85 |
+
pack: list[int] = []
|
| 86 |
+
|
| 87 |
+
def emit_ids(ids: list[int]) -> dict:
|
| 88 |
+
return {
|
| 89 |
+
"input_ids": [int(x) for x in ids],
|
| 90 |
+
"sequence_length": int(len(ids)),
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
def iter_record_chunks(ids: list[int]) -> Iterator[dict]:
|
| 94 |
+
for start in range(0, len(ids), max_len):
|
| 95 |
+
chunk = ids[start : start + max_len]
|
| 96 |
+
if len(chunk) >= min_len:
|
| 97 |
+
yield emit_ids(chunk)
|
| 98 |
+
if start + max_len >= len(ids):
|
| 99 |
+
break
|
| 100 |
+
|
| 101 |
+
def flush_pack() -> Iterator[dict]:
|
| 102 |
+
nonlocal pack
|
| 103 |
+
if len(pack) >= min_len:
|
| 104 |
+
yield emit_ids(pack)
|
| 105 |
+
pack = []
|
| 106 |
+
|
| 107 |
+
def append_stream(ids: list[int]) -> Iterator[dict]:
|
| 108 |
+
nonlocal pack
|
| 109 |
+
pack.extend(int(x) for x in ids)
|
| 110 |
+
while len(pack) >= max_len:
|
| 111 |
+
yield emit_ids(pack[:max_len])
|
| 112 |
+
pack = pack[max_len:]
|
| 113 |
+
|
| 114 |
+
for text in iter_text_records(
|
| 115 |
+
data_path,
|
| 116 |
+
text_column=text_column,
|
| 117 |
+
txt_record_mode=txt_record_mode,
|
| 118 |
+
openwebtext_split=openwebtext_split,
|
| 119 |
+
openwebtext_valid_records=openwebtext_valid_records,
|
| 120 |
+
detokenizer=detokenizer,
|
| 121 |
+
):
|
| 122 |
+
if not text:
|
| 123 |
+
continue
|
| 124 |
+
ids = tokenizer.encode(text, add_eos=add_eos, add_special_tokens=add_special_tokens)
|
| 125 |
+
if not ids:
|
| 126 |
+
continue
|
| 127 |
+
if packing_mode == "record_chunks":
|
| 128 |
+
yield from iter_record_chunks(ids)
|
| 129 |
+
elif packing_mode == "packed_records":
|
| 130 |
+
if len(ids) > max_len:
|
| 131 |
+
yield from flush_pack()
|
| 132 |
+
yield from iter_record_chunks(ids)
|
| 133 |
+
else:
|
| 134 |
+
if pack and len(pack) + len(ids) > max_len:
|
| 135 |
+
yield from flush_pack()
|
| 136 |
+
pack.extend(int(x) for x in ids)
|
| 137 |
+
if len(pack) >= max_len:
|
| 138 |
+
yield from flush_pack()
|
| 139 |
+
else:
|
| 140 |
+
yield from append_stream(ids)
|
| 141 |
+
seen_records += 1
|
| 142 |
+
if max_records > 0 and seen_records >= max_records:
|
| 143 |
+
break
|
| 144 |
+
if packing_mode in ("packed_records", "stream_chunks"):
|
| 145 |
+
yield from flush_pack()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _stats(args: argparse.Namespace) -> dict:
|
| 149 |
+
num_examples = 0
|
| 150 |
+
total_tokens = 0
|
| 151 |
+
min_len = None
|
| 152 |
+
max_len = 0
|
| 153 |
+
hist = {"lt128": 0, "128_255": 0, "256_511": 0, "512_1023": 0, "eq1024": 0}
|
| 154 |
+
for ex in _iter_examples(**_gen_kwargs(args)):
|
| 155 |
+
length = int(ex["sequence_length"])
|
| 156 |
+
num_examples += 1
|
| 157 |
+
total_tokens += length
|
| 158 |
+
min_len = length if min_len is None else min(min_len, length)
|
| 159 |
+
max_len = max(max_len, length)
|
| 160 |
+
if length < 128:
|
| 161 |
+
hist["lt128"] += 1
|
| 162 |
+
elif length < 256:
|
| 163 |
+
hist["128_255"] += 1
|
| 164 |
+
elif length < 512:
|
| 165 |
+
hist["256_511"] += 1
|
| 166 |
+
elif length < args.max_len:
|
| 167 |
+
hist["512_1023"] += 1
|
| 168 |
+
else:
|
| 169 |
+
hist["eq1024"] += 1
|
| 170 |
+
return {
|
| 171 |
+
"num_examples": int(num_examples),
|
| 172 |
+
"total_tokens": int(total_tokens),
|
| 173 |
+
"mean_length": float(total_tokens / num_examples) if num_examples else 0.0,
|
| 174 |
+
"min_length": int(min_len or 0),
|
| 175 |
+
"max_length": int(max_len),
|
| 176 |
+
"length_hist": hist,
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _gen_kwargs(args: argparse.Namespace) -> dict:
|
| 181 |
+
return {
|
| 182 |
+
"data_path": args.data_path,
|
| 183 |
+
"tokenizer_path": args.tokenizer_path,
|
| 184 |
+
"text_column": args.text_column,
|
| 185 |
+
"txt_record_mode": args.txt_record_mode,
|
| 186 |
+
"openwebtext_split": args.openwebtext_split,
|
| 187 |
+
"openwebtext_valid_records": args.openwebtext_valid_records,
|
| 188 |
+
"detokenizer": args.detokenizer,
|
| 189 |
+
"max_len": int(args.max_len),
|
| 190 |
+
"packing_mode": args.packing_mode,
|
| 191 |
+
"max_records": int(args.max_records),
|
| 192 |
+
"min_len": int(args.min_len),
|
| 193 |
+
"add_eos": bool(args.add_eos),
|
| 194 |
+
"add_special_tokens": bool(args.add_special_tokens),
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _make_limited_specs(args: argparse.Namespace) -> list[tuple[str, int, int | None]]:
|
| 199 |
+
from flowtext_lab.data import _make_file_specs
|
| 200 |
+
|
| 201 |
+
root = Path(args.data_path)
|
| 202 |
+
if root.is_dir():
|
| 203 |
+
files = sorted(
|
| 204 |
+
p for p in root.rglob("*")
|
| 205 |
+
if p.suffix.lower() in {".txt", ".jsonl", ".json", ".parquet"}
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
files = [root]
|
| 209 |
+
specs = _make_file_specs(files, args.openwebtext_split, int(args.openwebtext_valid_records))
|
| 210 |
+
if args.max_records <= 0:
|
| 211 |
+
return [(str(p), int(a), None if b is None else int(b)) for p, a, b in specs]
|
| 212 |
+
|
| 213 |
+
limited = []
|
| 214 |
+
remaining = int(args.max_records)
|
| 215 |
+
for path, start, stop in specs:
|
| 216 |
+
if remaining <= 0:
|
| 217 |
+
break
|
| 218 |
+
if stop is None:
|
| 219 |
+
limited.append((str(path), int(start), None))
|
| 220 |
+
break
|
| 221 |
+
count = max(0, int(stop) - int(start))
|
| 222 |
+
take = min(count, remaining)
|
| 223 |
+
if take > 0:
|
| 224 |
+
limited.append((str(path), int(start), int(start) + take))
|
| 225 |
+
remaining -= take
|
| 226 |
+
return limited
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _iter_parquet_text_batches(
|
| 230 |
+
path: Path,
|
| 231 |
+
*,
|
| 232 |
+
text_column: str | None,
|
| 233 |
+
row_start: int,
|
| 234 |
+
row_stop: int | None,
|
| 235 |
+
batch_size: int,
|
| 236 |
+
) -> Iterator[list[str]]:
|
| 237 |
+
import pyarrow.parquet as pq
|
| 238 |
+
|
| 239 |
+
pf = pq.ParquetFile(path)
|
| 240 |
+
col = text_column
|
| 241 |
+
if col is None:
|
| 242 |
+
names = set(pf.schema_arrow.names)
|
| 243 |
+
col = next((c for c in ("text", "content", "document", "article", "sentence") if c in names), None)
|
| 244 |
+
if col is None:
|
| 245 |
+
raise ValueError(f"Could not infer text column for {path}")
|
| 246 |
+
|
| 247 |
+
offset = 0
|
| 248 |
+
stop = pf.metadata.num_rows if row_stop is None else min(row_stop, pf.metadata.num_rows)
|
| 249 |
+
for batch in pf.iter_batches(columns=[col], batch_size=batch_size):
|
| 250 |
+
batch_start = offset
|
| 251 |
+
batch_stop = offset + batch.num_rows
|
| 252 |
+
offset = batch_stop
|
| 253 |
+
if batch_stop <= row_start:
|
| 254 |
+
continue
|
| 255 |
+
if batch_start >= stop:
|
| 256 |
+
break
|
| 257 |
+
local_start = max(0, row_start - batch_start)
|
| 258 |
+
local_stop = min(batch.num_rows, stop - batch_start)
|
| 259 |
+
values = batch.column(0).slice(local_start, local_stop - local_start).to_pylist()
|
| 260 |
+
texts = [str(value) for value in values if value is not None and str(value)]
|
| 261 |
+
if texts:
|
| 262 |
+
yield texts
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _iter_part_examples(
|
| 266 |
+
*,
|
| 267 |
+
spec: tuple[str, int, int | None],
|
| 268 |
+
tokenizer_path: str,
|
| 269 |
+
text_column: str | None,
|
| 270 |
+
detokenizer: str | None,
|
| 271 |
+
max_len: int,
|
| 272 |
+
packing_mode: str,
|
| 273 |
+
min_len: int,
|
| 274 |
+
add_eos: bool,
|
| 275 |
+
add_special_tokens: bool,
|
| 276 |
+
tokenize_batch_size: int,
|
| 277 |
+
) -> Iterator[dict]:
|
| 278 |
+
from flowtext_lab.text_detokenization import detokenize_text, infer_detokenizer_name
|
| 279 |
+
from flowtext_lab.tokenization import BpeTextTokenizer
|
| 280 |
+
|
| 281 |
+
path = Path(spec[0])
|
| 282 |
+
row_start = int(spec[1])
|
| 283 |
+
row_stop = None if spec[2] is None else int(spec[2])
|
| 284 |
+
tokenizer = BpeTextTokenizer.from_file(tokenizer_path)
|
| 285 |
+
resolved_detok = infer_detokenizer_name(raw_path=str(path), explicit=detokenizer)
|
| 286 |
+
pack: list[int] = []
|
| 287 |
+
|
| 288 |
+
def emit_ids(ids: list[int]) -> dict:
|
| 289 |
+
return {
|
| 290 |
+
"input_ids": [int(x) for x in ids],
|
| 291 |
+
"sequence_length": int(len(ids)),
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
def iter_record_chunks(ids: list[int]) -> Iterator[dict]:
|
| 295 |
+
for start in range(0, len(ids), max_len):
|
| 296 |
+
chunk = ids[start : start + max_len]
|
| 297 |
+
if len(chunk) >= min_len:
|
| 298 |
+
yield emit_ids(chunk)
|
| 299 |
+
if start + max_len >= len(ids):
|
| 300 |
+
break
|
| 301 |
+
|
| 302 |
+
def flush_pack() -> Iterator[dict]:
|
| 303 |
+
nonlocal pack
|
| 304 |
+
if len(pack) >= min_len:
|
| 305 |
+
yield emit_ids(pack)
|
| 306 |
+
pack = []
|
| 307 |
+
|
| 308 |
+
def append_stream(ids: list[int]) -> Iterator[dict]:
|
| 309 |
+
nonlocal pack
|
| 310 |
+
pack.extend(int(x) for x in ids)
|
| 311 |
+
while len(pack) >= max_len:
|
| 312 |
+
yield emit_ids(pack[:max_len])
|
| 313 |
+
pack = pack[max_len:]
|
| 314 |
+
|
| 315 |
+
for texts in _iter_parquet_text_batches(
|
| 316 |
+
path,
|
| 317 |
+
text_column=text_column,
|
| 318 |
+
row_start=row_start,
|
| 319 |
+
row_stop=row_stop,
|
| 320 |
+
batch_size=max(1, int(tokenize_batch_size)),
|
| 321 |
+
):
|
| 322 |
+
if resolved_detok:
|
| 323 |
+
texts = [detokenize_text(text, resolved_detok) for text in texts]
|
| 324 |
+
encoded = tokenizer.tokenizer.encode_batch(texts, add_special_tokens=add_special_tokens)
|
| 325 |
+
for enc in encoded:
|
| 326 |
+
ids = list(enc.ids)
|
| 327 |
+
if add_eos:
|
| 328 |
+
ids.append(tokenizer.eos_id)
|
| 329 |
+
if not ids:
|
| 330 |
+
continue
|
| 331 |
+
if packing_mode == "record_chunks":
|
| 332 |
+
yield from iter_record_chunks(ids)
|
| 333 |
+
elif packing_mode == "packed_records":
|
| 334 |
+
if len(ids) > max_len:
|
| 335 |
+
yield from flush_pack()
|
| 336 |
+
yield from iter_record_chunks(ids)
|
| 337 |
+
else:
|
| 338 |
+
if pack and len(pack) + len(ids) > max_len:
|
| 339 |
+
yield from flush_pack()
|
| 340 |
+
pack.extend(int(x) for x in ids)
|
| 341 |
+
if len(pack) >= max_len:
|
| 342 |
+
yield from flush_pack()
|
| 343 |
+
else:
|
| 344 |
+
yield from append_stream(ids)
|
| 345 |
+
if packing_mode in ("packed_records", "stream_chunks"):
|
| 346 |
+
yield from flush_pack()
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _build_part(task: dict) -> dict:
|
| 350 |
+
from datasets import Dataset, Features, Sequence, Value, disable_progress_bars
|
| 351 |
+
|
| 352 |
+
disable_progress_bars()
|
| 353 |
+
|
| 354 |
+
part_dir = Path(task["part_dir"])
|
| 355 |
+
if part_dir.exists():
|
| 356 |
+
shutil.rmtree(part_dir)
|
| 357 |
+
features = Features(
|
| 358 |
+
{
|
| 359 |
+
"input_ids": Sequence(Value("int32")),
|
| 360 |
+
"sequence_length": Value("int64"),
|
| 361 |
+
}
|
| 362 |
+
)
|
| 363 |
+
ds = Dataset.from_generator(
|
| 364 |
+
_iter_part_examples,
|
| 365 |
+
gen_kwargs={
|
| 366 |
+
"spec": task["spec"],
|
| 367 |
+
"tokenizer_path": task["tokenizer_path"],
|
| 368 |
+
"text_column": task["text_column"],
|
| 369 |
+
"detokenizer": task["detokenizer"],
|
| 370 |
+
"max_len": task["max_len"],
|
| 371 |
+
"packing_mode": task["packing_mode"],
|
| 372 |
+
"min_len": task["min_len"],
|
| 373 |
+
"add_eos": task["add_eos"],
|
| 374 |
+
"add_special_tokens": task["add_special_tokens"],
|
| 375 |
+
"tokenize_batch_size": task["tokenize_batch_size"],
|
| 376 |
+
},
|
| 377 |
+
features=features,
|
| 378 |
+
cache_dir=task["cache_dir"] or None,
|
| 379 |
+
)
|
| 380 |
+
ds.save_to_disk(str(part_dir), max_shard_size=task["max_shard_size"])
|
| 381 |
+
lengths = ds["sequence_length"] if len(ds) else []
|
| 382 |
+
total_tokens = int(sum(int(x) for x in lengths))
|
| 383 |
+
if task["cache_dir"]:
|
| 384 |
+
shutil.rmtree(task["cache_dir"], ignore_errors=True)
|
| 385 |
+
return {
|
| 386 |
+
"part_dir": str(part_dir),
|
| 387 |
+
"num_examples": int(len(ds)),
|
| 388 |
+
"total_tokens": total_tokens,
|
| 389 |
+
"spec": task["spec"],
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _part_is_complete(part_dir: Path) -> bool:
|
| 394 |
+
return (part_dir / "state.json").exists() and any(part_dir.glob("data-*.arrow"))
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _summarize_part(part_dir: Path, spec: tuple[str, int, int | None]) -> dict:
|
| 398 |
+
from datasets import load_from_disk
|
| 399 |
+
|
| 400 |
+
ds = load_from_disk(str(part_dir))
|
| 401 |
+
lengths = ds["sequence_length"] if len(ds) else []
|
| 402 |
+
total_tokens = int(sum(int(x) for x in lengths))
|
| 403 |
+
return {
|
| 404 |
+
"part_dir": str(part_dir),
|
| 405 |
+
"num_examples": int(len(ds)),
|
| 406 |
+
"total_tokens": total_tokens,
|
| 407 |
+
"spec": spec,
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _preload_datasets_for_fork() -> None:
|
| 412 |
+
# Importing datasets pulls in fsspec, which scans Python entry points.
|
| 413 |
+
# On this machine that scan can intermittently hit a corrupt/fragile zipped
|
| 414 |
+
# egg when many workers import at once. Preloading in the parent lets forked
|
| 415 |
+
# workers reuse sys.modules instead of racing through the entry point scan.
|
| 416 |
+
from datasets import Dataset, Features, Sequence, Value, disable_progress_bars, load_from_disk # noqa: F401
|
| 417 |
+
|
| 418 |
+
disable_progress_bars()
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def _parallel_build(args: argparse.Namespace) -> dict:
|
| 422 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 423 |
+
|
| 424 |
+
specs = _make_limited_specs(args)
|
| 425 |
+
if not specs:
|
| 426 |
+
raise RuntimeError("No input file specs found")
|
| 427 |
+
|
| 428 |
+
output_dir = Path(args.output_dir)
|
| 429 |
+
parts_root = output_dir / "parts"
|
| 430 |
+
parts_root.mkdir(parents=True, exist_ok=True)
|
| 431 |
+
|
| 432 |
+
tasks = []
|
| 433 |
+
part_results = []
|
| 434 |
+
for idx, spec in enumerate(specs):
|
| 435 |
+
part_dir = parts_root / f"part-{idx:05d}"
|
| 436 |
+
if args.resume_parts and _part_is_complete(part_dir):
|
| 437 |
+
part_results.append(_summarize_part(part_dir, spec))
|
| 438 |
+
continue
|
| 439 |
+
tasks.append(
|
| 440 |
+
{
|
| 441 |
+
"part_dir": str(part_dir),
|
| 442 |
+
"spec": spec,
|
| 443 |
+
"tokenizer_path": args.tokenizer_path,
|
| 444 |
+
"text_column": args.text_column,
|
| 445 |
+
"detokenizer": args.detokenizer,
|
| 446 |
+
"max_len": int(args.max_len),
|
| 447 |
+
"packing_mode": args.packing_mode,
|
| 448 |
+
"min_len": int(args.min_len),
|
| 449 |
+
"add_eos": bool(args.add_eos),
|
| 450 |
+
"add_special_tokens": bool(args.add_special_tokens),
|
| 451 |
+
"tokenize_batch_size": int(args.tokenize_batch_size),
|
| 452 |
+
"cache_dir": str(Path(args.cache_dir) / f"part-{idx:05d}") if args.cache_dir else "",
|
| 453 |
+
"max_shard_size": args.max_shard_size,
|
| 454 |
+
}
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
print(
|
| 458 |
+
f"[build] specs={len(specs)} existing={len(part_results)} "
|
| 459 |
+
f"todo={len(tasks)} num_proc={args.num_proc} output={output_dir}",
|
| 460 |
+
flush=True,
|
| 461 |
+
)
|
| 462 |
+
if tasks:
|
| 463 |
+
_preload_datasets_for_fork()
|
| 464 |
+
with ProcessPoolExecutor(max_workers=max(1, int(args.num_proc))) as pool:
|
| 465 |
+
futures = [pool.submit(_build_part, task) for task in tasks]
|
| 466 |
+
for done, fut in enumerate(as_completed(futures), start=1):
|
| 467 |
+
result = fut.result()
|
| 468 |
+
part_results.append(result)
|
| 469 |
+
print(
|
| 470 |
+
"[build] "
|
| 471 |
+
f"{done}/{len(futures)} {Path(result['part_dir']).name} "
|
| 472 |
+
f"examples={result['num_examples']} tokens={result['total_tokens']}",
|
| 473 |
+
flush=True,
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
part_results.sort(key=lambda x: x["part_dir"])
|
| 477 |
+
total_examples = sum(int(x["num_examples"]) for x in part_results)
|
| 478 |
+
total_tokens = sum(int(x["total_tokens"]) for x in part_results)
|
| 479 |
+
meta = {
|
| 480 |
+
"builder": "build_owt_t5_elf_dataset.py",
|
| 481 |
+
"format": f"elf_unconditional_tokenized_{args.packing_mode}_multipart",
|
| 482 |
+
"data_path": args.data_path,
|
| 483 |
+
"tokenizer_path": args.tokenizer_path,
|
| 484 |
+
"text_column": args.text_column,
|
| 485 |
+
"openwebtext_split": args.openwebtext_split,
|
| 486 |
+
"openwebtext_valid_records": args.openwebtext_valid_records,
|
| 487 |
+
"max_len": args.max_len,
|
| 488 |
+
"packing_mode": args.packing_mode,
|
| 489 |
+
"max_records": args.max_records,
|
| 490 |
+
"min_len": args.min_len,
|
| 491 |
+
"add_eos": args.add_eos,
|
| 492 |
+
"add_special_tokens": args.add_special_tokens,
|
| 493 |
+
"num_parts": len(part_results),
|
| 494 |
+
"num_examples": int(total_examples),
|
| 495 |
+
"total_tokens": int(total_tokens),
|
| 496 |
+
"mean_length": float(total_tokens / total_examples) if total_examples else 0.0,
|
| 497 |
+
"parts": part_results,
|
| 498 |
+
}
|
| 499 |
+
(output_dir / "elf_multi_part_meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True), encoding="utf-8")
|
| 500 |
+
|
| 501 |
+
if args.merge_parts:
|
| 502 |
+
from datasets import concatenate_datasets, load_from_disk
|
| 503 |
+
|
| 504 |
+
merged_tmp = output_dir / "_merged_tmp"
|
| 505 |
+
if merged_tmp.exists():
|
| 506 |
+
shutil.rmtree(merged_tmp)
|
| 507 |
+
datasets = [load_from_disk(result["part_dir"]) for result in part_results if result["num_examples"] > 0]
|
| 508 |
+
merged = datasets[0] if len(datasets) == 1 else concatenate_datasets(datasets)
|
| 509 |
+
merged.save_to_disk(str(merged_tmp), max_shard_size=args.max_shard_size)
|
| 510 |
+
for child in list(output_dir.iterdir()):
|
| 511 |
+
if child.name in {"_merged_tmp", "parts"}:
|
| 512 |
+
continue
|
| 513 |
+
if child.is_dir():
|
| 514 |
+
shutil.rmtree(child)
|
| 515 |
+
else:
|
| 516 |
+
child.unlink()
|
| 517 |
+
for child in list(merged_tmp.iterdir()):
|
| 518 |
+
child.rename(output_dir / child.name)
|
| 519 |
+
merged_tmp.rmdir()
|
| 520 |
+
if not args.keep_parts:
|
| 521 |
+
shutil.rmtree(parts_root)
|
| 522 |
+
meta["format"] = f"elf_unconditional_tokenized_{args.packing_mode}"
|
| 523 |
+
(output_dir / "elf_build_meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True), encoding="utf-8")
|
| 524 |
+
|
| 525 |
+
return meta
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def main() -> None:
|
| 529 |
+
args = parse_args()
|
| 530 |
+
output_dir = Path(args.output_dir)
|
| 531 |
+
|
| 532 |
+
if args.stats_only:
|
| 533 |
+
print(json.dumps(_stats(args), indent=2, sort_keys=True))
|
| 534 |
+
return
|
| 535 |
+
|
| 536 |
+
if output_dir.exists():
|
| 537 |
+
if not args.overwrite:
|
| 538 |
+
if not args.resume_parts:
|
| 539 |
+
raise SystemExit(f"output_dir exists: {output_dir}; pass --overwrite to replace it")
|
| 540 |
+
elif not args.resume_parts:
|
| 541 |
+
shutil.rmtree(output_dir)
|
| 542 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 543 |
+
|
| 544 |
+
if args.num_proc > 1:
|
| 545 |
+
meta = _parallel_build(args)
|
| 546 |
+
print(json.dumps({k: v for k, v in meta.items() if k != "parts"}, indent=2, sort_keys=True))
|
| 547 |
+
return
|
| 548 |
+
|
| 549 |
+
from datasets import Dataset, Features, Sequence, Value
|
| 550 |
+
|
| 551 |
+
features = Features(
|
| 552 |
+
{
|
| 553 |
+
"input_ids": Sequence(Value("int32")),
|
| 554 |
+
"sequence_length": Value("int64"),
|
| 555 |
+
}
|
| 556 |
+
)
|
| 557 |
+
ds = Dataset.from_generator(
|
| 558 |
+
_iter_examples,
|
| 559 |
+
gen_kwargs=_gen_kwargs(args),
|
| 560 |
+
features=features,
|
| 561 |
+
cache_dir=args.cache_dir or None,
|
| 562 |
+
)
|
| 563 |
+
ds.save_to_disk(str(output_dir), max_shard_size=args.max_shard_size)
|
| 564 |
+
|
| 565 |
+
meta = {
|
| 566 |
+
"builder": "build_owt_t5_elf_dataset.py",
|
| 567 |
+
"format": f"elf_unconditional_tokenized_{args.packing_mode}",
|
| 568 |
+
"data_path": args.data_path,
|
| 569 |
+
"tokenizer_path": args.tokenizer_path,
|
| 570 |
+
"text_column": args.text_column,
|
| 571 |
+
"openwebtext_split": args.openwebtext_split,
|
| 572 |
+
"openwebtext_valid_records": args.openwebtext_valid_records,
|
| 573 |
+
"max_len": args.max_len,
|
| 574 |
+
"packing_mode": args.packing_mode,
|
| 575 |
+
"max_records": args.max_records,
|
| 576 |
+
"min_len": args.min_len,
|
| 577 |
+
"add_eos": args.add_eos,
|
| 578 |
+
"add_special_tokens": args.add_special_tokens,
|
| 579 |
+
"num_examples": int(len(ds)),
|
| 580 |
+
"columns": list(ds.column_names),
|
| 581 |
+
}
|
| 582 |
+
(output_dir / "elf_build_meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True), encoding="utf-8")
|
| 583 |
+
print(json.dumps(meta, indent=2, sort_keys=True))
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
if __name__ == "__main__":
|
| 587 |
+
main()
|
LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_state_20260508.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import importlib.util
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
BASE = Path(__file__).with_name("eval_c1024_decode_sweep_20260507.py")
|
| 10 |
+
spec = importlib.util.spec_from_file_location("eval_c1024_decode_sweep_20260507", BASE)
|
| 11 |
+
if spec is None or spec.loader is None:
|
| 12 |
+
raise RuntimeError(f"cannot import {BASE}")
|
| 13 |
+
base = importlib.util.module_from_spec(spec)
|
| 14 |
+
sys.modules[spec.name] = base
|
| 15 |
+
spec.loader.exec_module(base)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def key_configs() -> list[base.DecodeConfig]:
|
| 19 |
+
return [
|
| 20 |
+
base.DecodeConfig(
|
| 21 |
+
"match_post_sem1_state_c16_t1p3",
|
| 22 |
+
"post",
|
| 23 |
+
1.0,
|
| 24 |
+
1.0,
|
| 25 |
+
"state",
|
| 26 |
+
endpoint_temp=1.3,
|
| 27 |
+
concentration_max=16.0,
|
| 28 |
+
),
|
| 29 |
+
base.DecodeConfig(
|
| 30 |
+
"match_post_sem1_state_c64_t1p3",
|
| 31 |
+
"post",
|
| 32 |
+
1.0,
|
| 33 |
+
1.0,
|
| 34 |
+
"state",
|
| 35 |
+
endpoint_temp=1.3,
|
| 36 |
+
concentration_max=64.0,
|
| 37 |
+
),
|
| 38 |
+
base.DecodeConfig(
|
| 39 |
+
"match_post_sem1_state_c1024_t1p3",
|
| 40 |
+
"post",
|
| 41 |
+
1.0,
|
| 42 |
+
1.0,
|
| 43 |
+
"state",
|
| 44 |
+
endpoint_temp=1.3,
|
| 45 |
+
concentration_max=1024.0,
|
| 46 |
+
),
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
base.default_configs = key_configs
|
| 51 |
+
base.main()
|
LTA_openwebtext_dualt/scripts/infer_lta_owt_t5_len128_uniform10k_then_lognsr_latest.sh
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
|
| 7 |
+
export PYTHONUNBUFFERED=1
|
| 8 |
+
export TOKENIZERS_PARALLELISM=false
|
| 9 |
+
|
| 10 |
+
RUN_PREFIX="${RUN_PREFIX:-lta_owt_t5_len128_uniform10k_then_lognsr}"
|
| 11 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
|
| 12 |
+
SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"
|
| 13 |
+
|
| 14 |
+
N_SAMPLES="${N_SAMPLES:-8}"
|
| 15 |
+
DECODE_BATCH="${DECODE_BATCH:-4}"
|
| 16 |
+
SCORE_BATCH="${SCORE_BATCH:-4}"
|
| 17 |
+
MAX_LEN="${MAX_LEN:-128}"
|
| 18 |
+
STEPS="${STEPS:-1024}"
|
| 19 |
+
ENDPOINT_TEMPS="${ENDPOINT_TEMPS:-1.0,1.15,1.30,1.45}"
|
| 20 |
+
|
| 21 |
+
DECODE_RULE="${DECODE_RULE:-dirichlet_resample}"
|
| 22 |
+
MODEL_T_MODE="${MODEL_T_MODE:-post}"
|
| 23 |
+
TIME_SCHEDULE="${TIME_SCHEDULE:-lognsr_gumbel}"
|
| 24 |
+
TIME_GUMBEL_LOC="${TIME_GUMBEL_LOC:-2.2}"
|
| 25 |
+
TIME_GUMBEL_SCALE="${TIME_GUMBEL_SCALE:-0.8}"
|
| 26 |
+
CONCENTRATION_MIN="${CONCENTRATION_MIN:-1}"
|
| 27 |
+
CONCENTRATION_MAX="${CONCENTRATION_MAX:-64}"
|
| 28 |
+
NOISE_INIT="${NOISE_INIT:-dirichlet}"
|
| 29 |
+
FINAL_FROM="${FINAL_FROM:-state}"
|
| 30 |
+
FINAL_SAMPLE_MODE="${FINAL_SAMPLE_MODE:-argmax}"
|
| 31 |
+
|
| 32 |
+
pick_run() {
|
| 33 |
+
local suffix="$1"
|
| 34 |
+
find runs -maxdepth 1 -type d -name "${RUN_PREFIX}*${suffix}" -printf "%T@ %p\n" 2>/dev/null \
|
| 35 |
+
| sort -nr \
|
| 36 |
+
| head -n 1 \
|
| 37 |
+
| cut -d' ' -f2-
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
RUN_DIR="${RUN_DIR:-}"
|
| 41 |
+
if [[ -z "${RUN_DIR}" ]]; then
|
| 42 |
+
RUN_DIR="$(pick_run "_resume_lognsr_sde_rollin")"
|
| 43 |
+
fi
|
| 44 |
+
if [[ -z "${RUN_DIR}" ]]; then
|
| 45 |
+
RUN_DIR="$(pick_run "_warmup_uniform_norollin")"
|
| 46 |
+
fi
|
| 47 |
+
if [[ -z "${RUN_DIR}" || ! -d "${RUN_DIR}" ]]; then
|
| 48 |
+
echo "[infer] could not find run dir for prefix=${RUN_PREFIX}" >&2
|
| 49 |
+
exit 1
|
| 50 |
+
fi
|
| 51 |
+
|
| 52 |
+
CKPT="${CKPT:-}"
|
| 53 |
+
if [[ -z "${CKPT}" ]]; then
|
| 54 |
+
CKPT="$(ls -1 "${RUN_DIR}"/step_*.pt 2>/dev/null | sort | tail -n 1 || true)"
|
| 55 |
+
fi
|
| 56 |
+
if [[ -z "${CKPT}" || ! -f "${CKPT}" ]]; then
|
| 57 |
+
echo "[infer] could not find checkpoint under ${RUN_DIR}" >&2
|
| 58 |
+
exit 1
|
| 59 |
+
fi
|
| 60 |
+
|
| 61 |
+
RUN_BASENAME="$(basename "${RUN_DIR}")"
|
| 62 |
+
CKPT_BASENAME="$(basename "${CKPT}" .pt)"
|
| 63 |
+
OUT_DIR="${OUT_DIR:-docs/lta_samples/metrics_20260519/${RUN_BASENAME}_${CKPT_BASENAME}_len128_lm1bgood_sdeish_n${N_SAMPLES}}"
|
| 64 |
+
OUT_JSONL="${OUT_DIR}/summary.jsonl"
|
| 65 |
+
mkdir -p "${OUT_DIR}"
|
| 66 |
+
|
| 67 |
+
echo "[infer] run=${RUN_DIR}"
|
| 68 |
+
echo "[infer] ckpt=${CKPT}"
|
| 69 |
+
echo "[infer] out=${OUT_JSONL}"
|
| 70 |
+
echo "[infer] decode_rule=${DECODE_RULE} steps=${STEPS} cmax=${CONCENTRATION_MAX} model_t=${MODEL_T_MODE} temps=${ENDPOINT_TEMPS}"
|
| 71 |
+
|
| 72 |
+
python scripts/standard_genppl_entropy_latest_decode.py \
|
| 73 |
+
--checkpoint "${CKPT}" \
|
| 74 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 75 |
+
--scorer "${SCORER}" \
|
| 76 |
+
--output "${OUT_JSONL}" \
|
| 77 |
+
--max_len "${MAX_LEN}" \
|
| 78 |
+
--n_samples "${N_SAMPLES}" \
|
| 79 |
+
--decode_batch "${DECODE_BATCH}" \
|
| 80 |
+
--score_batch "${SCORE_BATCH}" \
|
| 81 |
+
--score_max_length "${MAX_LEN}" \
|
| 82 |
+
--steps "${STEPS}" \
|
| 83 |
+
--model_t_mode "${MODEL_T_MODE}" \
|
| 84 |
+
--decode_time_schedule "${TIME_SCHEDULE}" \
|
| 85 |
+
--decode_time_gumbel_loc "${TIME_GUMBEL_LOC}" \
|
| 86 |
+
--decode_time_gumbel_scale "${TIME_GUMBEL_SCALE}" \
|
| 87 |
+
--decode_rule "${DECODE_RULE}" \
|
| 88 |
+
--concentration_min "${CONCENTRATION_MIN}" \
|
| 89 |
+
--concentration_max "${CONCENTRATION_MAX}" \
|
| 90 |
+
--noise_init "${NOISE_INIT}" \
|
| 91 |
+
--endpoint_temps "${ENDPOINT_TEMPS}" \
|
| 92 |
+
--final_from "${FINAL_FROM}" \
|
| 93 |
+
--final_sample_mode "${FINAL_SAMPLE_MODE}" \
|
| 94 |
+
--save_samples "${N_SAMPLES}"
|
| 95 |
+
|
| 96 |
+
echo "[infer] summaries:"
|
| 97 |
+
python - "${OUT_JSONL}" <<'PY'
|
| 98 |
+
import json, sys
|
| 99 |
+
path = sys.argv[1]
|
| 100 |
+
with open(path, encoding="utf-8") as f:
|
| 101 |
+
for line in f:
|
| 102 |
+
row = json.loads(line)
|
| 103 |
+
if row.get("type") != "summary":
|
| 104 |
+
continue
|
| 105 |
+
d = row["decode"]
|
| 106 |
+
stripped = row.get("stripped_genppl", {})
|
| 107 |
+
div = row.get("diversity", {})
|
| 108 |
+
print(
|
| 109 |
+
f"temp={d['endpoint_temp']:.2f} final={d['final_from']} "
|
| 110 |
+
f"ppl={stripped.get('ppl')} entropy={div.get('sample_entropy')} "
|
| 111 |
+
f"top_mass={div.get('top_token_mass')}"
|
| 112 |
+
)
|
| 113 |
+
PY
|
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c1024_fullycoupled_8gpu_small_1m.sh
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 6 |
+
export TOKENIZERS_PARALLELISM=false
|
| 7 |
+
export PYTHONUNBUFFERED=1
|
| 8 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 9 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 10 |
+
|
| 11 |
+
# Fully-coupled t ablation:
|
| 12 |
+
# model_t == support/Dirichlet t == semantic endpoint t
|
| 13 |
+
RUN_NAME="${RUN_NAME:-lta_lm1b_dirichlet_categorical_fullvocab_c1024_fullycoupled_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0}"
|
| 14 |
+
DATA_PATH="${DATA_PATH:-data/lm1b_train_parquet}"
|
| 15 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
|
| 16 |
+
TEXT_COLUMN="${TEXT_COLUMN:-}"
|
| 17 |
+
OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-all}"
|
| 18 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 19 |
+
LOG_FILE="${LOG_FILE:-logs/${RUN_NAME}.log}"
|
| 20 |
+
|
| 21 |
+
NNODES="${NNODES:-1}"
|
| 22 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 23 |
+
NODE_RANK="${NODE_RANK:-0}"
|
| 24 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 25 |
+
MASTER_PORT="${MASTER_PORT:-29631}"
|
| 26 |
+
|
| 27 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 28 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-64}"
|
| 29 |
+
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 30 |
+
WARMUP_STEPS="${WARMUP_STEPS:-2500}"
|
| 31 |
+
MAX_LEN="${MAX_LEN:-128}"
|
| 32 |
+
WRAP_MODE="${WRAP_MODE:-stream}"
|
| 33 |
+
WRAP_RECORD_BUFFER_SIZE="${WRAP_RECORD_BUFFER_SIZE:-200}"
|
| 34 |
+
NUM_WORKERS="${NUM_WORKERS:-0}"
|
| 35 |
+
LOG_EVERY="${LOG_EVERY:-100}"
|
| 36 |
+
SAVE_EVERY="${SAVE_EVERY:-20000}"
|
| 37 |
+
LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 38 |
+
EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 39 |
+
RESUME_PATH="${RESUME_PATH:-}"
|
| 40 |
+
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 41 |
+
ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
|
| 42 |
+
FORCE_DISABLE_TORCH_COMPILE="${FORCE_DISABLE_TORCH_COMPILE:-1}"
|
| 43 |
+
|
| 44 |
+
if [[ "${FORCE_DISABLE_TORCH_COMPILE}" == "1" ]]; then
|
| 45 |
+
ENABLE_TORCH_COMPILE=0
|
| 46 |
+
fi
|
| 47 |
+
if [[ "${DATA_PATH}" == *"lm1b_train_parquet"* && "${NUM_WORKERS}" != "0" ]]; then
|
| 48 |
+
echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
|
| 49 |
+
NUM_WORKERS=0
|
| 50 |
+
fi
|
| 51 |
+
|
| 52 |
+
COMPILE_ARGS=()
|
| 53 |
+
if [[ "${ENABLE_TORCH_COMPILE}" == "1" ]]; then
|
| 54 |
+
COMPILE_ARGS+=(--torch_compile --compile_mode reduce-overhead)
|
| 55 |
+
fi
|
| 56 |
+
RESUME_ARGS=()
|
| 57 |
+
if [[ -n "${RESUME_PATH}" ]]; then
|
| 58 |
+
RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
|
| 59 |
+
fi
|
| 60 |
+
TEXT_COLUMN_ARGS=()
|
| 61 |
+
if [[ -n "${TEXT_COLUMN}" ]]; then
|
| 62 |
+
TEXT_COLUMN_ARGS+=(--text_column "${TEXT_COLUMN}")
|
| 63 |
+
fi
|
| 64 |
+
|
| 65 |
+
if [[ -f "${SAVE_DIR}/args.json" && -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
|
| 66 |
+
echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
|
| 67 |
+
echo "Use a new RUN_NAME/SAVE_DIR, set RESUME_PATH to resume, or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
|
| 68 |
+
exit 2
|
| 69 |
+
fi
|
| 70 |
+
|
| 71 |
+
mkdir -p logs runs "${SAVE_DIR}"
|
| 72 |
+
echo "[launch] method=categorical_fullvocab_c1024_fullycoupled host=$(hostname) time=$(date -Iseconds)"
|
| 73 |
+
echo "[launch] cwd=$(pwd)"
|
| 74 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 75 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 76 |
+
echo "[launch] log_file=${LOG_FILE}"
|
| 77 |
+
|
| 78 |
+
python -m torch.distributed.run \
|
| 79 |
+
--nnodes="${NNODES}" \
|
| 80 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 81 |
+
--node_rank="${NODE_RANK}" \
|
| 82 |
+
--master_addr="${MASTER_ADDR}" \
|
| 83 |
+
--master_port="${MASTER_PORT}" \
|
| 84 |
+
train.py \
|
| 85 |
+
--data_path "${DATA_PATH}" \
|
| 86 |
+
"${TEXT_COLUMN_ARGS[@]}" \
|
| 87 |
+
--openwebtext_split "${OPENWEBTEXT_SPLIT}" \
|
| 88 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 89 |
+
--save_dir "${SAVE_DIR}" \
|
| 90 |
+
--wrap \
|
| 91 |
+
--wrap_mode "${WRAP_MODE}" \
|
| 92 |
+
--wrap_record_buffer_size "${WRAP_RECORD_BUFFER_SIZE}" \
|
| 93 |
+
--max_len "${MAX_LEN}" \
|
| 94 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 95 |
+
--num_workers "${NUM_WORKERS}" \
|
| 96 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 97 |
+
--total_steps "${TOTAL_STEPS}" \
|
| 98 |
+
--log_every "${LOG_EVERY}" \
|
| 99 |
+
--eval_every "${EVAL_EVERY}" \
|
| 100 |
+
--save_every "${SAVE_EVERY}" \
|
| 101 |
+
--latest_every "${LATEST_EVERY}" \
|
| 102 |
+
--lr 3e-4 \
|
| 103 |
+
--weight_decay 0 \
|
| 104 |
+
--adam_beta1 0.9 \
|
| 105 |
+
--adam_beta2 0.999 \
|
| 106 |
+
--adam_eps 1e-8 \
|
| 107 |
+
--warmup_steps "${WARMUP_STEPS}" \
|
| 108 |
+
--lr_schedule constant_warmup \
|
| 109 |
+
--grad_clip 1.0 \
|
| 110 |
+
--seed 123 \
|
| 111 |
+
--d_model 768 \
|
| 112 |
+
--cond_dim 128 \
|
| 113 |
+
--n_layers 12 \
|
| 114 |
+
--n_heads 12 \
|
| 115 |
+
--dim_ff 3072 \
|
| 116 |
+
--dropout 0.1 \
|
| 117 |
+
--model_type ddit \
|
| 118 |
+
--state_format prob \
|
| 119 |
+
--bridge dirichlet \
|
| 120 |
+
--target_loss hard_ce \
|
| 121 |
+
--target_prob 1.0 \
|
| 122 |
+
--min_t 0.0 \
|
| 123 |
+
--max_t 1.0 \
|
| 124 |
+
--dual_t \
|
| 125 |
+
--corrupt_t_mode same \
|
| 126 |
+
--corrupt_min_t 0.0 \
|
| 127 |
+
--corrupt_max_t 1.0 \
|
| 128 |
+
--min_mask_ratio 0.1 \
|
| 129 |
+
--max_mask_ratio 1.0 \
|
| 130 |
+
--wrong_token_replace_prob 1.0 \
|
| 131 |
+
--wrong_token_schedule linear_t \
|
| 132 |
+
--wrong_token_exp_k 1.0 \
|
| 133 |
+
--dirichlet_concentration_min 1.0 \
|
| 134 |
+
--dirichlet_concentration_max 1024.0 \
|
| 135 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 136 |
+
--dirichlet_semantic_t_mode same \
|
| 137 |
+
--dirichlet_semantic_t_value 0.0 \
|
| 138 |
+
--categorical_wrong_from_full_vocab \
|
| 139 |
+
--simplex_bridge_sampler dirichlet \
|
| 140 |
+
--eps 1e-8 \
|
| 141 |
+
--infer_steps 128 \
|
| 142 |
+
--decode_damping 1.0 \
|
| 143 |
+
--max_gamma 1.0 \
|
| 144 |
+
--decode_solver flowmap \
|
| 145 |
+
--noise_init logistic_normal \
|
| 146 |
+
--bridge_noise_init logistic_normal \
|
| 147 |
+
--noise_sigma -1 \
|
| 148 |
+
"${RESUME_ARGS[@]}" \
|
| 149 |
+
"${COMPILE_ARGS[@]}" \
|
| 150 |
+
--bf16 2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c16_dualt_4gpu_small_1m.sh
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 6 |
+
export TOKENIZERS_PARALLELISM=false
|
| 7 |
+
export PYTHONUNBUFFERED=1
|
| 8 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 9 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 10 |
+
|
| 11 |
+
# C=16 categorical dual-t LM1B, full-vocab wrong-token endpoint.
|
| 12 |
+
# This is the 4-GPU counterpart of the 8-GPU full-vocab run; global batch stays 512.
|
| 13 |
+
|
| 14 |
+
C_MAX="${C_MAX:-16.0}"
|
| 15 |
+
C_TAG="${C_TAG:-c${C_MAX//./p}}"
|
| 16 |
+
RUN_NAME="${RUN_NAME:-lta_lm1b_dirichlet_categorical_fullvocab_${C_TAG}_dualt_flmpack_onehot_hardce_ddit_small_len128_gbs512_4gpu_1m_nw0}"
|
| 17 |
+
DATA_PATH="${DATA_PATH:-data/lm1b_train_parquet}"
|
| 18 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
|
| 19 |
+
DETOKENIZER="${DETOKENIZER:-auto}"
|
| 20 |
+
TEXT_COLUMN="${TEXT_COLUMN:-}"
|
| 21 |
+
OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-all}"
|
| 22 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 23 |
+
LOG_FILE="${LOG_FILE:-logs/${RUN_NAME}.log}"
|
| 24 |
+
|
| 25 |
+
NNODES="${NNODES:-1}"
|
| 26 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
|
| 27 |
+
NODE_RANK="${NODE_RANK:-0}"
|
| 28 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 29 |
+
MASTER_PORT="${MASTER_PORT:-29641}"
|
| 30 |
+
|
| 31 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 32 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-64}"
|
| 33 |
+
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 34 |
+
WARMUP_STEPS="${WARMUP_STEPS:-2500}"
|
| 35 |
+
MAX_LEN="${MAX_LEN:-128}"
|
| 36 |
+
WRAP_MODE="${WRAP_MODE:-stream}"
|
| 37 |
+
WRAP_RECORD_BUFFER_SIZE="${WRAP_RECORD_BUFFER_SIZE:-200}"
|
| 38 |
+
NUM_WORKERS="${NUM_WORKERS:-0}"
|
| 39 |
+
LOG_EVERY="${LOG_EVERY:-100}"
|
| 40 |
+
SAVE_EVERY="${SAVE_EVERY:-20000}"
|
| 41 |
+
LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 42 |
+
EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 43 |
+
RESUME_PATH="${RESUME_PATH:-}"
|
| 44 |
+
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 45 |
+
ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
|
| 46 |
+
FORCE_DISABLE_TORCH_COMPILE="${FORCE_DISABLE_TORCH_COMPILE:-1}"
|
| 47 |
+
|
| 48 |
+
if [[ "${FORCE_DISABLE_TORCH_COMPILE}" == "1" ]]; then
|
| 49 |
+
ENABLE_TORCH_COMPILE=0
|
| 50 |
+
fi
|
| 51 |
+
if [[ "${DATA_PATH}" == *"lm1b_train_parquet"* && "${NUM_WORKERS}" != "0" ]]; then
|
| 52 |
+
echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
|
| 53 |
+
NUM_WORKERS=0
|
| 54 |
+
fi
|
| 55 |
+
|
| 56 |
+
COMPILE_ARGS=()
|
| 57 |
+
if [[ "${ENABLE_TORCH_COMPILE}" == "1" ]]; then
|
| 58 |
+
COMPILE_ARGS+=(--torch_compile --compile_mode reduce-overhead)
|
| 59 |
+
fi
|
| 60 |
+
RESUME_ARGS=()
|
| 61 |
+
if [[ -n "${RESUME_PATH}" ]]; then
|
| 62 |
+
RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
|
| 63 |
+
fi
|
| 64 |
+
TEXT_COLUMN_ARGS=()
|
| 65 |
+
if [[ -n "${TEXT_COLUMN}" ]]; then
|
| 66 |
+
TEXT_COLUMN_ARGS+=(--text_column "${TEXT_COLUMN}")
|
| 67 |
+
fi
|
| 68 |
+
|
| 69 |
+
if [[ -f "${SAVE_DIR}/args.json" && -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
|
| 70 |
+
echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
|
| 71 |
+
echo "Use a new RUN_NAME/SAVE_DIR, set RESUME_PATH to resume, or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
|
| 72 |
+
exit 2
|
| 73 |
+
fi
|
| 74 |
+
|
| 75 |
+
mkdir -p logs runs "${SAVE_DIR}"
|
| 76 |
+
echo "[launch] method=categorical_fullvocab C_MAX=${C_MAX} host=$(hostname) time=$(date -Iseconds)"
|
| 77 |
+
echo "[launch] cwd=$(pwd)"
|
| 78 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 79 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 80 |
+
echo "[launch] log_file=${LOG_FILE}"
|
| 81 |
+
echo "[launch] nproc_per_node=${NPROC_PER_NODE} global_batch_size=${GLOBAL_BATCH_SIZE} per_gpu_batch_size=${PER_GPU_BATCH_SIZE}"
|
| 82 |
+
|
| 83 |
+
python -m torch.distributed.run \
|
| 84 |
+
--nnodes="${NNODES}" \
|
| 85 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 86 |
+
--node_rank="${NODE_RANK}" \
|
| 87 |
+
--master_addr="${MASTER_ADDR}" \
|
| 88 |
+
--master_port="${MASTER_PORT}" \
|
| 89 |
+
train.py \
|
| 90 |
+
--data_path "${DATA_PATH}" \
|
| 91 |
+
"${TEXT_COLUMN_ARGS[@]}" \
|
| 92 |
+
--openwebtext_split "${OPENWEBTEXT_SPLIT}" \
|
| 93 |
+
--detokenizer "${DETOKENIZER}" \
|
| 94 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 95 |
+
--save_dir "${SAVE_DIR}" \
|
| 96 |
+
--wrap \
|
| 97 |
+
--wrap_mode "${WRAP_MODE}" \
|
| 98 |
+
--wrap_record_buffer_size "${WRAP_RECORD_BUFFER_SIZE}" \
|
| 99 |
+
--max_len "${MAX_LEN}" \
|
| 100 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 101 |
+
--num_workers "${NUM_WORKERS}" \
|
| 102 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 103 |
+
--total_steps "${TOTAL_STEPS}" \
|
| 104 |
+
--log_every "${LOG_EVERY}" \
|
| 105 |
+
--eval_every "${EVAL_EVERY}" \
|
| 106 |
+
--save_every "${SAVE_EVERY}" \
|
| 107 |
+
--latest_every "${LATEST_EVERY}" \
|
| 108 |
+
--lr 3e-4 \
|
| 109 |
+
--weight_decay 0 \
|
| 110 |
+
--adam_beta1 0.9 \
|
| 111 |
+
--adam_beta2 0.999 \
|
| 112 |
+
--adam_eps 1e-8 \
|
| 113 |
+
--warmup_steps "${WARMUP_STEPS}" \
|
| 114 |
+
--lr_schedule constant_warmup \
|
| 115 |
+
--grad_clip 1.0 \
|
| 116 |
+
--seed 123 \
|
| 117 |
+
--d_model 768 \
|
| 118 |
+
--cond_dim 128 \
|
| 119 |
+
--n_layers 12 \
|
| 120 |
+
--n_heads 12 \
|
| 121 |
+
--dim_ff 3072 \
|
| 122 |
+
--dropout 0.1 \
|
| 123 |
+
--model_type ddit \
|
| 124 |
+
--state_format prob \
|
| 125 |
+
--bridge dirichlet \
|
| 126 |
+
--target_loss hard_ce \
|
| 127 |
+
--target_prob 1.0 \
|
| 128 |
+
--min_t 0.0 \
|
| 129 |
+
--max_t 1.0 \
|
| 130 |
+
--dual_t \
|
| 131 |
+
--corrupt_t_mode independent \
|
| 132 |
+
--corrupt_min_t 0.0 \
|
| 133 |
+
--corrupt_max_t 1.0 \
|
| 134 |
+
--min_mask_ratio 0.1 \
|
| 135 |
+
--max_mask_ratio 1.0 \
|
| 136 |
+
--wrong_token_replace_prob 1.0 \
|
| 137 |
+
--wrong_token_schedule linear_t \
|
| 138 |
+
--wrong_token_exp_k 1.0 \
|
| 139 |
+
--dirichlet_concentration_min 1.0 \
|
| 140 |
+
--dirichlet_concentration_max "${C_MAX}" \
|
| 141 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 142 |
+
--dirichlet_semantic_t_mode independent \
|
| 143 |
+
--dirichlet_semantic_t_value 0.0 \
|
| 144 |
+
--categorical_wrong_from_full_vocab \
|
| 145 |
+
--eps 1e-8 \
|
| 146 |
+
--infer_steps 128 \
|
| 147 |
+
--decode_damping 1.0 \
|
| 148 |
+
--max_gamma 1.0 \
|
| 149 |
+
--decode_solver flowmap \
|
| 150 |
+
--noise_init logistic_normal \
|
| 151 |
+
--bridge_noise_init logistic_normal \
|
| 152 |
+
--noise_sigma -1 \
|
| 153 |
+
"${RESUME_ARGS[@]}" \
|
| 154 |
+
"${COMPILE_ARGS[@]}" \
|
| 155 |
+
--bf16 2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_c1024_fullycoupled_8gpu_len1024_gpt2_cached_chunks_1m.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
# Explicit cached-chunk OWT/GPT-2 run.
|
| 7 |
+
# Uses the already-built cache:
|
| 8 |
+
# openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k
|
| 9 |
+
#
|
| 10 |
+
# Data processing:
|
| 11 |
+
# tokenize records with GPT-2 tokenizer
|
| 12 |
+
# append GPT-2 EOT after each record
|
| 13 |
+
# concatenate stream
|
| 14 |
+
# split into payload_len=1022
|
| 15 |
+
# wrap as [EOT] + payload + [EOT]
|
| 16 |
+
# train from fixed memmap chunks with DistributedSampler shuffle
|
| 17 |
+
|
| 18 |
+
export OWT_CACHED_CHUNKS=1
|
| 19 |
+
export OWT_CHUNK_CACHE_DIR="${OWT_CHUNK_CACHE_DIR:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
|
| 20 |
+
# Default to reusing the prebuilt cache; set OWT_CHUNK_CACHE_REBUILD=1 only when
|
| 21 |
+
# intentionally refreshing or repairing the cached chunk pool.
|
| 22 |
+
export OWT_CHUNK_CACHE_REBUILD="${OWT_CHUNK_CACHE_REBUILD:-0}"
|
| 23 |
+
|
| 24 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 25 |
+
export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 26 |
+
export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 27 |
+
export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 28 |
+
export WARMUP_STEPS="${WARMUP_STEPS:-2000}"
|
| 29 |
+
export MAX_LEN="${MAX_LEN:-1024}"
|
| 30 |
+
export NUM_WORKERS="${NUM_WORKERS:-4}"
|
| 31 |
+
export DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-2}"
|
| 32 |
+
export LOG_EVERY="${LOG_EVERY:-100}"
|
| 33 |
+
export SAVE_EVERY="${SAVE_EVERY:-20000}"
|
| 34 |
+
export LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 35 |
+
export EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 36 |
+
export ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
|
| 37 |
+
export ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 38 |
+
export OPTIMIZER="${OPTIMIZER:-adamw}"
|
| 39 |
+
export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}"
|
| 40 |
+
export MUON_NS_STEPS="${MUON_NS_STEPS:-5}"
|
| 41 |
+
export MUON_UPDATE_SCALE="${MUON_UPDATE_SCALE:-1.0}"
|
| 42 |
+
export EMA_DECAY="${EMA_DECAY:-0.0}"
|
| 43 |
+
export EMA_START_STEP="${EMA_START_STEP:-0}"
|
| 44 |
+
export ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 45 |
+
export ACTIVATION_CHECKPOINTING="${ACTIVATION_CHECKPOINTING:-0}"
|
| 46 |
+
export ACTIVATION_CHECKPOINT_INTERVAL="${ACTIVATION_CHECKPOINT_INTERVAL:-1}"
|
| 47 |
+
export DDP_STATIC_GRAPH="${DDP_STATIC_GRAPH:-0}"
|
| 48 |
+
export DDP_GRADIENT_AS_BUCKET_VIEW="${DDP_GRADIENT_AS_BUCKET_VIEW:-1}"
|
| 49 |
+
export BLOCKING_DATA_TRANSFER="${BLOCKING_DATA_TRANSFER:-0}"
|
| 50 |
+
export FULL_TRAIN_STATS="${FULL_TRAIN_STATS:-0}"
|
| 51 |
+
|
| 52 |
+
export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
|
| 53 |
+
export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
|
| 54 |
+
export TEXT_COLUMN="${TEXT_COLUMN:-text}"
|
| 55 |
+
export OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-train_minus_100k}"
|
| 56 |
+
export DETOKENIZER="${DETOKENIZER:-auto}"
|
| 57 |
+
|
| 58 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_dirichlet_categorical_fullvocab_c1024_fullycoupled_gpt2_cached_chunks_len1024_gbs${GLOBAL_BATCH_SIZE}_${NPROC_PER_NODE}gpu_1m_nw${NUM_WORKERS}}"
|
| 59 |
+
|
| 60 |
+
bash scripts/launch_lta_owt_categorical_fullvocab_c1024_fullycoupled_8gpu_small_1m.sh
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v8192_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
# 8k compact GPT2-BPE variant of the v2048 fully-coupled mask=1 baseline.
|
| 7 |
+
# Keep the actual training recipe centralized in the v2048 script; this wrapper
|
| 8 |
+
# only swaps tokenizer/data/run labels.
|
| 9 |
+
export VOCAB_SIZE="${VOCAB_SIZE:-8192}"
|
| 10 |
+
export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-compact-gpt2bpe-v8192-stream1024-train-minus-100k}"
|
| 11 |
+
export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/lta_tokenizers/owt_compact_gpt2bpe_v8192/tokenizer.json}"
|
| 12 |
+
export COMPACT_VARIANT_LABEL="${COMPACT_VARIANT_LABEL:-compact_gpt2bpe_v8192_stream1024_fullycoupled_mask0p1-1p0_wd0p1_fp32}"
|
| 13 |
+
export T_SAMPLING_MODE="${T_SAMPLING_MODE:-uniform}"
|
| 14 |
+
export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
|
| 15 |
+
export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 16 |
+
|
| 17 |
+
sanitize_label() {
|
| 18 |
+
printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
T_SAMPLING_LOGIT_MEAN_FOR_NAME="${T_SAMPLING_LOGIT_MEAN:--1.5}"
|
| 22 |
+
T_SAMPLING_LOGIT_STD_FOR_NAME="${T_SAMPLING_LOGIT_STD:-0.8}"
|
| 23 |
+
MIN_MASK_RATIO_FOR_NAME="${MIN_MASK_RATIO:-1.0}"
|
| 24 |
+
MAX_MASK_RATIO_FOR_NAME="${MAX_MASK_RATIO:-1.0}"
|
| 25 |
+
|
| 26 |
+
T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN_FOR_NAME}")"
|
| 27 |
+
T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD_FOR_NAME}")"
|
| 28 |
+
MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO_FOR_NAME}")"
|
| 29 |
+
MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO_FOR_NAME}")"
|
| 30 |
+
if [[ "${T_SAMPLING_MODE}" == "logit_normal" ]]; then
|
| 31 |
+
T_SAMPLING_LABEL="logitnormal_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}"
|
| 32 |
+
else
|
| 33 |
+
T_SAMPLING_LABEL="$(sanitize_label "${T_SAMPLING_MODE}")t"
|
| 34 |
+
fi
|
| 35 |
+
|
| 36 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_compact_gpt2bpe_v8192_stream1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_${T_SAMPLING_LABEL}_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
|
| 37 |
+
export LOG_DIR="${LOG_DIR:-logs/compact_gpt2bpe_v8192_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu}"
|
| 38 |
+
|
| 39 |
+
bash scripts/launch_lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_elfaligned_t5_logitnormal_8gpu.sh
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
| 8 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 9 |
+
export TOKENIZERS_PARALLELISM=false
|
| 10 |
+
export PYTHONUNBUFFERED=1
|
| 11 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 12 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 13 |
+
|
| 14 |
+
# ELF-aligned simplex run:
|
| 15 |
+
# architecture: ddit_elf = no adaLN, prefix time tokens, qk norm, RoPE, RMSNorm, SwiGLU
|
| 16 |
+
# tokenizer/data: T5-small tokenizer, one OWT record per example, pad/truncate to 1024
|
| 17 |
+
# optimizer: Muon, lr 0.002, wd 0, constant LR after 0.5 epoch warmup
|
| 18 |
+
# time sampling: sigmoid(N(T_LOGIT_MEAN, T_LOGIT_STD^2)); defaults match ELF
|
| 19 |
+
# The old ddit path and GPT2 cached scripts are untouched.
|
| 20 |
+
|
| 21 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
|
| 22 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
|
| 23 |
+
|
| 24 |
+
NNODES="${NNODES:-1}"
|
| 25 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 26 |
+
NODE_RANK="${NODE_RANK:-0}"
|
| 27 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 28 |
+
MASTER_PORT="${MASTER_PORT:-32091}"
|
| 29 |
+
|
| 30 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 31 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 32 |
+
EPOCHS="${EPOCHS:-5}"
|
| 33 |
+
NUM_WORKERS="${NUM_WORKERS:-8}"
|
| 34 |
+
DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 35 |
+
LOG_EVERY="${LOG_EVERY:-100}"
|
| 36 |
+
LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 37 |
+
EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 38 |
+
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 39 |
+
ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 40 |
+
|
| 41 |
+
LR="${LR:-0.002}"
|
| 42 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.0}"
|
| 43 |
+
ADAM_BETA1="${ADAM_BETA1:-0.9}"
|
| 44 |
+
ADAM_BETA2="${ADAM_BETA2:-0.999}"
|
| 45 |
+
ADAM_EPS="${ADAM_EPS:-1e-8}"
|
| 46 |
+
MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}"
|
| 47 |
+
MUON_NS_STEPS="${MUON_NS_STEPS:-5}"
|
| 48 |
+
MUON_UPDATE_SCALE="${MUON_UPDATE_SCALE:-1.0}"
|
| 49 |
+
GRAD_CLIP="${GRAD_CLIP:-1.0}"
|
| 50 |
+
EMA_DECAY="${EMA_DECAY:-0.9999}"
|
| 51 |
+
EMA_START_STEP="${EMA_START_STEP:-0}"
|
| 52 |
+
T_LOGIT_MEAN="${T_LOGIT_MEAN:--1.5}"
|
| 53 |
+
T_LOGIT_STD="${T_LOGIT_STD:-0.8}"
|
| 54 |
+
LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
|
| 55 |
+
LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
|
| 56 |
+
OUTPUT_INIT_STD="${OUTPUT_INIT_STD:-0.0}"
|
| 57 |
+
|
| 58 |
+
sanitize_label() {
|
| 59 |
+
printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_LOGIT_MEAN}")"
|
| 63 |
+
T_LOGIT_STD_LABEL="$(sanitize_label "${T_LOGIT_STD}")"
|
| 64 |
+
LOSS_T_MIN_WEIGHT_LABEL="$(sanitize_label "${LOSS_T_MIN_WEIGHT}")"
|
| 65 |
+
|
| 66 |
+
RUN_NAME="${RUN_NAME:-lta_owt_t5record_len1024_elfaligned_dditelf_muon_logitnormal_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}_${LOSS_T_WEIGHT_MODE}_floor${LOSS_T_MIN_WEIGHT_LABEL}_gbs512_8gpu_5epoch_$(date +%Y%m%d_%H%M%S)}"
|
| 67 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 68 |
+
LOG_DIR="${LOG_DIR:-logs/elfaligned_t5record_8gpu}"
|
| 69 |
+
LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
|
| 70 |
+
|
| 71 |
+
NUM_RECORDS=$(python - <<PY
|
| 72 |
+
from pathlib import Path
|
| 73 |
+
import pyarrow.parquet as pq
|
| 74 |
+
root = Path("${DATA_PATH}")
|
| 75 |
+
files = sorted(root.rglob("*.parquet")) if root.is_dir() else [root]
|
| 76 |
+
rows = sum(pq.ParquetFile(str(p)).metadata.num_rows for p in files)
|
| 77 |
+
print(max(0, rows - 100_000))
|
| 78 |
+
PY
|
| 79 |
+
)
|
| 80 |
+
STEPS_PER_EPOCH=$(( (NUM_RECORDS + GLOBAL_BATCH_SIZE - 1) / GLOBAL_BATCH_SIZE ))
|
| 81 |
+
SAVE_EVERY="${SAVE_EVERY:-${STEPS_PER_EPOCH}}"
|
| 82 |
+
|
| 83 |
+
if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
|
| 84 |
+
echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
|
| 85 |
+
echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
|
| 86 |
+
exit 2
|
| 87 |
+
fi
|
| 88 |
+
|
| 89 |
+
mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
|
| 90 |
+
|
| 91 |
+
TF32_FLAG="--allow_tf32"
|
| 92 |
+
TF32_LABEL="true"
|
| 93 |
+
if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
|
| 94 |
+
TF32_FLAG="--no-allow_tf32"
|
| 95 |
+
TF32_LABEL="false"
|
| 96 |
+
fi
|
| 97 |
+
|
| 98 |
+
echo "[launch] method=owt_elfaligned_t5record_dditelf host=$(hostname) time=$(date -Iseconds)"
|
| 99 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 100 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 101 |
+
echo "[launch] log_file=${LOG_FILE}"
|
| 102 |
+
echo "[launch] data_path=${DATA_PATH}"
|
| 103 |
+
echo "[launch] tokenizer=${TOKENIZER_PATH}"
|
| 104 |
+
echo "[launch] records=${NUM_RECORDS} epochs=${EPOCHS} approx_steps_per_epoch=${STEPS_PER_EPOCH} save_every=${SAVE_EVERY}"
|
| 105 |
+
echo "[launch] optimizer=muon_impl=optax grouping=hidden_2d lr=${LR} wd=${WEIGHT_DECAY} adam_fallback_wd=0 momentum=${MUON_MOMENTUM} ns=${MUON_NS_STEPS} nesterov=true width_scale=true adam_fallback_b2=${ADAM_BETA2} ema=${EMA_DECAY}"
|
| 106 |
+
echo "[launch] model=ddit_elf rmsnorm qk_norm=true swiglu no_adaln output_bias=false output_init_std=${OUTPUT_INIT_STD} time_tokens=4 mode_tokens=0"
|
| 107 |
+
echo "[launch] data=record_pad_truncate pad=pad add_special_tokens=false t5-small fp32=true bf16=false tf32=${TF32_LABEL}"
|
| 108 |
+
echo "[launch] t_sampling=logit_normal mean=${T_LOGIT_MEAN} std=${T_LOGIT_STD} loss_t_weight=${LOSS_T_WEIGHT_MODE} loss_t_min_weight=${LOSS_T_MIN_WEIGHT} warmup_epochs=0.5"
|
| 109 |
+
|
| 110 |
+
python -m torch.distributed.run \
|
| 111 |
+
--nnodes="${NNODES}" \
|
| 112 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 113 |
+
--node_rank="${NODE_RANK}" \
|
| 114 |
+
--master_addr="${MASTER_ADDR}" \
|
| 115 |
+
--master_port="${MASTER_PORT}" \
|
| 116 |
+
train.py \
|
| 117 |
+
--data_path "${DATA_PATH}" \
|
| 118 |
+
--openwebtext_split train_minus_100k \
|
| 119 |
+
--text_column text \
|
| 120 |
+
--detokenizer auto \
|
| 121 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 122 |
+
--save_dir "${SAVE_DIR}" \
|
| 123 |
+
--record_pad_truncate \
|
| 124 |
+
--record_pad_token pad \
|
| 125 |
+
--record_shuffle_buffer 10000 \
|
| 126 |
+
--max_len 1024 \
|
| 127 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 128 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 129 |
+
--num_workers "${NUM_WORKERS}" \
|
| 130 |
+
--dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
|
| 131 |
+
--epochs "${EPOCHS}" \
|
| 132 |
+
--total_steps 1 \
|
| 133 |
+
--warmup_epochs 0.5 \
|
| 134 |
+
--log_every "${LOG_EVERY}" \
|
| 135 |
+
--eval_every "${EVAL_EVERY}" \
|
| 136 |
+
--save_every "${SAVE_EVERY}" \
|
| 137 |
+
--latest_every "${LATEST_EVERY}" \
|
| 138 |
+
--optimizer muon \
|
| 139 |
+
--muon_impl optax \
|
| 140 |
+
--lr "${LR}" \
|
| 141 |
+
--lr_schedule constant_warmup \
|
| 142 |
+
--min_lr 0 \
|
| 143 |
+
--weight_decay "${WEIGHT_DECAY}" \
|
| 144 |
+
--adam_beta1 "${ADAM_BETA1}" \
|
| 145 |
+
--adam_beta2 "${ADAM_BETA2}" \
|
| 146 |
+
--adam_eps "${ADAM_EPS}" \
|
| 147 |
+
--muon_momentum "${MUON_MOMENTUM}" \
|
| 148 |
+
--muon_ns_steps "${MUON_NS_STEPS}" \
|
| 149 |
+
--muon_update_scale "${MUON_UPDATE_SCALE}" \
|
| 150 |
+
--muon_nesterov \
|
| 151 |
+
--muon_width_scale \
|
| 152 |
+
--ema_decay "${EMA_DECAY}" \
|
| 153 |
+
--ema_start_step "${EMA_START_STEP}" \
|
| 154 |
+
--grad_clip "${GRAD_CLIP}" \
|
| 155 |
+
--seed 42 \
|
| 156 |
+
--d_model 768 \
|
| 157 |
+
--cond_dim 128 \
|
| 158 |
+
--n_layers 12 \
|
| 159 |
+
--n_heads 12 \
|
| 160 |
+
--dim_ff 3072 \
|
| 161 |
+
--dropout 0.0 \
|
| 162 |
+
--no-output_bias \
|
| 163 |
+
--output_init_std "${OUTPUT_INIT_STD}" \
|
| 164 |
+
--norm_type rmsnorm \
|
| 165 |
+
--model_type ddit_elf \
|
| 166 |
+
--elf_num_time_tokens 4 \
|
| 167 |
+
--elf_num_model_mode_tokens 0 \
|
| 168 |
+
--qk_norm \
|
| 169 |
+
--state_format prob \
|
| 170 |
+
--bridge dirichlet \
|
| 171 |
+
--target_loss hard_ce \
|
| 172 |
+
--loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
|
| 173 |
+
--loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
|
| 174 |
+
--target_prob 1.0 \
|
| 175 |
+
--min_t 0.0 \
|
| 176 |
+
--max_t 1.0 \
|
| 177 |
+
--t_sampling_mode logit_normal \
|
| 178 |
+
--t_sampling_logit_mean "${T_LOGIT_MEAN}" \
|
| 179 |
+
--t_sampling_logit_std "${T_LOGIT_STD}" \
|
| 180 |
+
--t_sampling_eps 1e-4 \
|
| 181 |
+
--dual_t \
|
| 182 |
+
--corrupt_t_mode same \
|
| 183 |
+
--corrupt_min_t 0.0 \
|
| 184 |
+
--corrupt_max_t 1.0 \
|
| 185 |
+
--min_mask_ratio 0.1 \
|
| 186 |
+
--max_mask_ratio 1.0 \
|
| 187 |
+
--wrong_token_replace_prob 1.0 \
|
| 188 |
+
--wrong_token_schedule linear_t \
|
| 189 |
+
--wrong_token_exp_k 1.0 \
|
| 190 |
+
--dirichlet_concentration_min 1.0 \
|
| 191 |
+
--dirichlet_concentration_max 1024 \
|
| 192 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 193 |
+
--dirichlet_semantic_t_mode same \
|
| 194 |
+
--dirichlet_semantic_t_value 0.0 \
|
| 195 |
+
--categorical_wrong_from_full_vocab \
|
| 196 |
+
--simplex_bridge_sampler dirichlet \
|
| 197 |
+
--eps 1e-8 \
|
| 198 |
+
--infer_steps 1024 \
|
| 199 |
+
--decode_damping 1.0 \
|
| 200 |
+
--max_gamma 1.0 \
|
| 201 |
+
--decode_solver flowmap \
|
| 202 |
+
--noise_init logistic_normal \
|
| 203 |
+
--bridge_noise_init logistic_normal \
|
| 204 |
+
--noise_sigma -1 \
|
| 205 |
+
"${TF32_FLAG}" \
|
| 206 |
+
--activation_checkpointing \
|
| 207 |
+
--activation_checkpoint_scope mlp \
|
| 208 |
+
--ddp_gradient_as_bucket_view \
|
| 209 |
+
2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_outwd0p5_8gpu.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:-0.5}"
|
| 7 |
+
export WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
|
| 8 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
|
| 9 |
+
export LOG_DIR="${LOG_DIR:-logs/fullycoupled_outwd0p5_8gpu}"
|
| 10 |
+
|
| 11 |
+
bash scripts/launch_lta_owt_fullycoupled_wd0p1_fp32_8gpu.sh
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_grad_k1_rho025_subset10k_4gpu_100k.sh
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
|
| 7 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
|
| 8 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 9 |
+
export TOKENIZERS_PARALLELISM=false
|
| 10 |
+
export PYTHONUNBUFFERED=1
|
| 11 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 12 |
+
|
| 13 |
+
free_port() {
|
| 14 |
+
python3 - <<'PY'
|
| 15 |
+
import socket
|
| 16 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 17 |
+
s.bind(("127.0.0.1", 0))
|
| 18 |
+
print(s.getsockname()[1])
|
| 19 |
+
PY
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
|
| 23 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
|
| 24 |
+
MAX_RECORDS="${MAX_RECORDS:-10000}"
|
| 25 |
+
TOTAL_STEPS="${TOTAL_STEPS:-100000}"
|
| 26 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 27 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-16}"
|
| 28 |
+
RUN_NAME="${RUN_NAME:-lta_owt_t5_3l_d256_rollin_grad_p50_k1_rho0_0p25_uniformt_maxrec10k_4gpu_100k_$(date +%Y%m%d_%H%M%S)}"
|
| 29 |
+
MASTER_PORT="${MASTER_PORT:-$(free_port)}"
|
| 30 |
+
LOG_DIR="${LOG_DIR:-logs/elfaligned_t5tokenized_4gpu}"
|
| 31 |
+
mkdir -p "${LOG_DIR}" "runs/${RUN_NAME}"
|
| 32 |
+
LOG_FILE="${LOG_DIR}/${RUN_NAME}.log"
|
| 33 |
+
|
| 34 |
+
echo "[launch] run_name=${RUN_NAME}" | tee -a "${LOG_FILE}"
|
| 35 |
+
echo "[launch] data=${DATA_PATH} max_records=${MAX_RECORDS} tokenizer=${TOKENIZER_PATH}" | tee -a "${LOG_FILE}"
|
| 36 |
+
echo "[launch] cuda=${CUDA_VISIBLE_DEVICES} nproc=${NPROC_PER_NODE} gbs=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE} total_steps=${TOTAL_STEPS}" | tee -a "${LOG_FILE}"
|
| 37 |
+
|
| 38 |
+
torchrun \
|
| 39 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 40 |
+
--master_port="${MASTER_PORT}" \
|
| 41 |
+
train.py \
|
| 42 |
+
--data_path "${DATA_PATH}" \
|
| 43 |
+
--max_records "${MAX_RECORDS}" \
|
| 44 |
+
--tokenized_hf \
|
| 45 |
+
--tokenized_pad_token pad \
|
| 46 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 47 |
+
--save_dir "runs/${RUN_NAME}" \
|
| 48 |
+
--max_len 1024 \
|
| 49 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 50 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 51 |
+
--num_workers 0 \
|
| 52 |
+
--epochs 0 \
|
| 53 |
+
--total_steps "${TOTAL_STEPS}" \
|
| 54 |
+
--warmup_steps 1 \
|
| 55 |
+
--warmup_epochs 0.5 \
|
| 56 |
+
--log_every 100 \
|
| 57 |
+
--eval_every 0 \
|
| 58 |
+
--save_every 5000 \
|
| 59 |
+
--latest_every 1000 \
|
| 60 |
+
--optimizer muon \
|
| 61 |
+
--muon_impl optax \
|
| 62 |
+
--lr 0.002 \
|
| 63 |
+
--lr_schedule constant_warmup \
|
| 64 |
+
--min_lr 0.0 \
|
| 65 |
+
--weight_decay 0.1 \
|
| 66 |
+
--output_weight_decay -1 \
|
| 67 |
+
--adamw_param_groups nanogpt \
|
| 68 |
+
--adam_beta1 0.9 \
|
| 69 |
+
--adam_beta2 0.999 \
|
| 70 |
+
--adam_eps 1e-8 \
|
| 71 |
+
--ema_decay 0.9999 \
|
| 72 |
+
--ema_start_step 0 \
|
| 73 |
+
--grad_clip 1.0 \
|
| 74 |
+
--seed 42 \
|
| 75 |
+
--d_model 256 \
|
| 76 |
+
--cond_dim 128 \
|
| 77 |
+
--n_layers 3 \
|
| 78 |
+
--n_heads 4 \
|
| 79 |
+
--dim_ff 1024 \
|
| 80 |
+
--dropout 0.0 \
|
| 81 |
+
--no-output_bias \
|
| 82 |
+
--output_init_std 0 \
|
| 83 |
+
--norm_type rmsnorm \
|
| 84 |
+
--qk_norm \
|
| 85 |
+
--model_type ddit_elf \
|
| 86 |
+
--ddit_mlp_type gelu \
|
| 87 |
+
--state_format prob \
|
| 88 |
+
--bridge dirichlet \
|
| 89 |
+
--target_loss hard_ce \
|
| 90 |
+
--loss_t_weight_mode none \
|
| 91 |
+
--loss_t_min_weight 0.0 \
|
| 92 |
+
--rollout_train_prob 0.50 \
|
| 93 |
+
--rollout_train_time_mode sampled_path \
|
| 94 |
+
--rollout_train_steps 1 \
|
| 95 |
+
--rollout_train_steps_min -1 \
|
| 96 |
+
--rollout_train_infer_steps 1 \
|
| 97 |
+
--rollout_train_s_dist uniform \
|
| 98 |
+
--rollout_train_s_min_frac 0.0 \
|
| 99 |
+
--rollout_train_s_max_frac 0.25 \
|
| 100 |
+
--rollout_train_temp 1.0 \
|
| 101 |
+
--rollout_train_max_gamma 1.0 \
|
| 102 |
+
--rollout_train_corrupt_only \
|
| 103 |
+
--rollout_train_samplewise \
|
| 104 |
+
--rollout_train_selected_only \
|
| 105 |
+
--no-rollout_train_compute_always \
|
| 106 |
+
--rollout_train_keep_grad \
|
| 107 |
+
--rollout_train_sync_t \
|
| 108 |
+
--target_prob 1.0 \
|
| 109 |
+
--min_t 0.0 \
|
| 110 |
+
--max_t 1.0 \
|
| 111 |
+
--t_sampling_mode uniform \
|
| 112 |
+
--t_sampling_logit_mean -1.5 \
|
| 113 |
+
--t_sampling_logit_std 0.8 \
|
| 114 |
+
--t_sampling_eps 1e-4 \
|
| 115 |
+
--dual_t \
|
| 116 |
+
--corrupt_t_mode same \
|
| 117 |
+
--corrupt_min_t 0.0 \
|
| 118 |
+
--corrupt_max_t 1.0 \
|
| 119 |
+
--min_mask_ratio 1.0 \
|
| 120 |
+
--max_mask_ratio 1.0 \
|
| 121 |
+
--mask_mixture_original_prob 0.0 \
|
| 122 |
+
--mask_mixture_lowk_prob 0.0 \
|
| 123 |
+
--mask_mixture_lowcorrupt_prob 0.0 \
|
| 124 |
+
--mask_mixture_block_prob 0.0 \
|
| 125 |
+
--mask_mixture_all_prob 1.0 \
|
| 126 |
+
--wrong_token_replace_prob 1.0 \
|
| 127 |
+
--wrong_token_schedule linear_t \
|
| 128 |
+
--wrong_token_exp_k 1.0 \
|
| 129 |
+
--dirichlet_concentration_min 1.0 \
|
| 130 |
+
--dirichlet_concentration_max 1024 \
|
| 131 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 132 |
+
--dirichlet_semantic_t_mode same \
|
| 133 |
+
--dirichlet_semantic_t_value 0.0 \
|
| 134 |
+
--categorical_wrong_from_full_vocab \
|
| 135 |
+
--simplex_bridge_sampler dirichlet \
|
| 136 |
+
--eps 1e-8 \
|
| 137 |
+
--infer_steps 1024 \
|
| 138 |
+
--decode_damping 1.0 \
|
| 139 |
+
--max_gamma 1.0 \
|
| 140 |
+
--decode_solver flowmap \
|
| 141 |
+
--noise_init logistic_normal \
|
| 142 |
+
--bridge_noise_init logistic_normal \
|
| 143 |
+
--noise_sigma -1 \
|
| 144 |
+
--allow_tf32 \
|
| 145 |
+
--activation_checkpointing \
|
| 146 |
+
--activation_checkpoint_scope mlp \
|
| 147 |
+
--ddp_gradient_as_bucket_view \
|
| 148 |
+
2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/run_lta_lm1b_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 8 |
+
export MASTER_PORT="${MASTER_PORT:-32682}"
|
| 9 |
+
|
| 10 |
+
export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 11 |
+
export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 12 |
+
export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 13 |
+
export WARMUP_STEPS="${WARMUP_STEPS:-2500}"
|
| 14 |
+
export SAVE_EVERY="${SAVE_EVERY:-10000}"
|
| 15 |
+
export LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 16 |
+
export LOG_EVERY="${LOG_EVERY:-100}"
|
| 17 |
+
|
| 18 |
+
export MAX_LEN="${MAX_LEN:-1024}"
|
| 19 |
+
export VOCAB_SIZE="${VOCAB_SIZE:-30522}"
|
| 20 |
+
export CMIN="${CMIN:-${VOCAB_SIZE}}"
|
| 21 |
+
export CMAX="${CMAX:-61044}"
|
| 22 |
+
|
| 23 |
+
export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
|
| 24 |
+
export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 25 |
+
export CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
|
| 26 |
+
|
| 27 |
+
# Keep watcher off by default for the 1M run; enable explicitly to avoid
|
| 28 |
+
# competing with training GPUs on busy 8-card nodes.
|
| 29 |
+
export WATCH_ENABLED="${WATCH_ENABLED:-0}"
|
| 30 |
+
|
| 31 |
+
DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
|
| 32 |
+
export RUN_NAME="${RUN_NAME:-lta_lm1b_dirichlet_len1024_Cv_to_2v_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
|
| 33 |
+
|
| 34 |
+
bash scripts/run_lta_lm1b_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
|
LTA_openwebtext_dualt/scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 8 |
+
export MASTER_PORT="${MASTER_PORT:-32682}"
|
| 9 |
+
|
| 10 |
+
export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 11 |
+
export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 12 |
+
export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 13 |
+
export WARMUP_STEPS="${WARMUP_STEPS:-2500}"
|
| 14 |
+
export SAVE_EVERY="${SAVE_EVERY:-10000}"
|
| 15 |
+
export LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 16 |
+
export LOG_EVERY="${LOG_EVERY:-100}"
|
| 17 |
+
|
| 18 |
+
export MAX_LEN="${MAX_LEN:-1024}"
|
| 19 |
+
export VOCAB_SIZE="${VOCAB_SIZE:-30522}"
|
| 20 |
+
export CMIN="${CMIN:-${VOCAB_SIZE}}"
|
| 21 |
+
export CMAX="${CMAX:-61044}"
|
| 22 |
+
|
| 23 |
+
export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
|
| 24 |
+
export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 25 |
+
export CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
|
| 26 |
+
|
| 27 |
+
# Keep watcher off by default for the 1M run; enable explicitly to avoid
|
| 28 |
+
# competing with training GPUs on busy 8-card nodes.
|
| 29 |
+
export WATCH_ENABLED="${WATCH_ENABLED:-0}"
|
| 30 |
+
|
| 31 |
+
DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
|
| 32 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_dirichlet_len1024_Cv_to_2v_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
|
| 33 |
+
|
| 34 |
+
bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
|
LTA_openwebtext_dualt/scripts/run_lta_owt_t5_absrope_adaln_dirichlet_len1024_Cv_to_2v_8gpu_mask0p1_1p0_sameT_1m_save10k.sh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
# T5-tokenized OWT, DDiT = RoPE + adaLN-zero, with learned absolute position
|
| 7 |
+
# embeddings added before RoPE. The bridge/model t is shared (sameT).
|
| 8 |
+
export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
|
| 9 |
+
export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
|
| 10 |
+
export TOKENIZED_HF="${TOKENIZED_HF:-1}"
|
| 11 |
+
export TOKENIZED_PAD_TOKEN="${TOKENIZED_PAD_TOKEN:-pad}"
|
| 12 |
+
|
| 13 |
+
export VOCAB_SIZE="${VOCAB_SIZE:-32100}"
|
| 14 |
+
export CMIN="${CMIN:-32100}"
|
| 15 |
+
export CMAX="${CMAX:-64200}"
|
| 16 |
+
|
| 17 |
+
export ABS_POS_EMBED="${ABS_POS_EMBED:-1}"
|
| 18 |
+
export CORRUPT_T_MODE="${CORRUPT_T_MODE:-same}"
|
| 19 |
+
export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
|
| 20 |
+
export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 21 |
+
export MASK_MIXTURE_ORIGINAL_PROB="${MASK_MIXTURE_ORIGINAL_PROB:-0.0}"
|
| 22 |
+
export MASK_MIXTURE_ALL_PROB="${MASK_MIXTURE_ALL_PROB:-0.0}"
|
| 23 |
+
|
| 24 |
+
export DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
|
| 25 |
+
export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 26 |
+
export SAVE_EVERY="${SAVE_EVERY:-10000}"
|
| 27 |
+
export LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 28 |
+
export WATCH_ENABLED="${WATCH_ENABLED:-1}"
|
| 29 |
+
export WATCH_STEP_INTERVAL="${WATCH_STEP_INTERVAL:-10000}"
|
| 30 |
+
export WATCH_N_SAMPLES="${WATCH_N_SAMPLES:-128}"
|
| 31 |
+
export WATCH_CUDA_VISIBLE_DEVICES="${WATCH_CUDA_VISIBLE_DEVICES:-7}"
|
| 32 |
+
|
| 33 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_t5_absrope_adaln_dirichlet_len1024_Cv_to_2v_mask0p1_1p0_sameT_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
|
| 34 |
+
export WATCH_OUT_BASE="${WATCH_OUT_BASE:-docs/lta_samples/metrics_${DATE_TAG}/owt_t5_absrope_adaln_Cv_to_2v_mask0p1_1p0_sameT_sde_gumbel_topp${WATCH_ENDPOINT_TOP_P:-0.95}_tau${WATCH_GUMBEL_TAU_START:-1.0}_to_${WATCH_GUMBEL_TAU_END:-0.2}_blend_c${CMIN}_${CMAX}_n${WATCH_N_SAMPLES}/${RUN_NAME}}"
|
| 35 |
+
|
| 36 |
+
bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
|
LTA_openwebtext_dualt/scripts/run_train8_wrong_floor_pilots_4gpu.sh
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 7 |
+
export TOKENIZERS_PARALLELISM=false
|
| 8 |
+
export PYTHONUNBUFFERED=1
|
| 9 |
+
|
| 10 |
+
BASE_CACHE="${BASE_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
|
| 11 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
|
| 12 |
+
MAX_LEN="${MAX_LEN:-256}"
|
| 13 |
+
N_SAMPLES="${N_SAMPLES:-64}"
|
| 14 |
+
INFER_STEPS="${INFER_STEPS:-128}"
|
| 15 |
+
STEP_CHUNK="${STEP_CHUNK:-1000}"
|
| 16 |
+
MAX_TOTAL_STEPS="${MAX_TOTAL_STEPS:-20000}"
|
| 17 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-128}"
|
| 18 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 19 |
+
GROUP_STAMP="${GROUP_STAMP:-$(date +%Y%m%d_%H%M%S)}"
|
| 20 |
+
OUT_ROOT="${OUT_ROOT:-docs/lta_samples/metrics_20260517/wrong_floor_pilots_len${MAX_LEN}_bs512_ode128_${GROUP_STAMP}}"
|
| 21 |
+
DRIVER_LOG="${DRIVER_LOG:-logs/wrong_floor_pilots_4gpu/${GROUP_STAMP}.log}"
|
| 22 |
+
CURVE_CSV="${CURVE_CSV:-${OUT_ROOT}/hit_ratio_curve.csv}"
|
| 23 |
+
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"
|
| 24 |
+
|
| 25 |
+
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
|
| 26 |
+
vocab_size="$(
|
| 27 |
+
python - "$cache" <<'PY'
|
| 28 |
+
import json
|
| 29 |
+
import sys
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
meta = json.loads((Path(sys.argv[1]) / "meta.json").read_text())
|
| 32 |
+
print(int(meta.get("compact_vocab_size", meta.get("vocab_size"))))
|
| 33 |
+
PY
|
| 34 |
+
)"
|
| 35 |
+
|
| 36 |
+
if [[ ! -f "${CURVE_CSV}" ]]; then
|
| 37 |
+
echo "config,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
latest_step() {
|
| 41 |
+
local run_name="$1"
|
| 42 |
+
python - "$run_name" <<'PY'
|
| 43 |
+
import re
|
| 44 |
+
import sys
|
| 45 |
+
from pathlib import Path
|
| 46 |
+
run = Path("runs") / sys.argv[1]
|
| 47 |
+
steps = []
|
| 48 |
+
for path in run.glob("step_*.pt"):
|
| 49 |
+
m = re.search(r"step_(\d+)\.pt$", path.name)
|
| 50 |
+
if m:
|
| 51 |
+
steps.append(int(m.group(1)))
|
| 52 |
+
print(max(steps) if steps else 0)
|
| 53 |
+
PY
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
free_port() {
|
| 57 |
+
python - <<'PY'
|
| 58 |
+
import socket
|
| 59 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 60 |
+
s.bind(("127.0.0.1", 0))
|
| 61 |
+
print(s.getsockname()[1])
|
| 62 |
+
PY
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
eval_latest() {
|
| 66 |
+
local config="$1"
|
| 67 |
+
local run_name="$2"
|
| 68 |
+
local target_step="$3"
|
| 69 |
+
local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
|
| 70 |
+
mkdir -p "${out_dir}"
|
| 71 |
+
CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py \
|
| 72 |
+
--runs_glob "runs/${run_name}" \
|
| 73 |
+
--data_dir "${cache}" \
|
| 74 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 75 |
+
--out_dir "${out_dir}" \
|
| 76 |
+
--max_len "${MAX_LEN}" \
|
| 77 |
+
--n_samples "${N_SAMPLES}" \
|
| 78 |
+
--batch_size "${N_SAMPLES}" \
|
| 79 |
+
--latest_only \
|
| 80 |
+
--endpoint_softenings none \
|
| 81 |
+
--steps "${INFER_STEPS}" \
|
| 82 |
+
--decode_rule flowmap \
|
| 83 |
+
--time_schedule logit_normal \
|
| 84 |
+
--time_logit_mean -1.5 \
|
| 85 |
+
--time_logit_std 0.8 \
|
| 86 |
+
--model_t_mode post \
|
| 87 |
+
--c_min 1 \
|
| 88 |
+
--c_max 512 \
|
| 89 |
+
--late_temp 1.0 \
|
| 90 |
+
--final_from state \
|
| 91 |
+
--final_decode argmax
|
| 92 |
+
python - "$config" "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$CURVE_CSV" <<'PY'
|
| 93 |
+
import json
|
| 94 |
+
import sys
|
| 95 |
+
from pathlib import Path
|
| 96 |
+
config = sys.argv[1]
|
| 97 |
+
out = Path(sys.argv[2])
|
| 98 |
+
n = int(sys.argv[3])
|
| 99 |
+
global_batch = int(sys.argv[4])
|
| 100 |
+
max_len = int(sys.argv[5])
|
| 101 |
+
curve = Path(sys.argv[6])
|
| 102 |
+
row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
|
| 103 |
+
views = int(row["ckpt_step"]) * global_batch
|
| 104 |
+
tokens = views * max_len
|
| 105 |
+
print(
|
| 106 |
+
"RESULT "
|
| 107 |
+
f"config={config} ckpt_step={row['ckpt_step']} views={views} "
|
| 108 |
+
f"token_acc={row['token_acc_mean']:.4f} exact={row['exact_count']}/{n} "
|
| 109 |
+
f"exact_refs={row['exact_ref_count']} hits={row['exact_ref_hits']}",
|
| 110 |
+
flush=True,
|
| 111 |
+
)
|
| 112 |
+
with curve.open("a", encoding="utf-8") as f:
|
| 113 |
+
f.write(
|
| 114 |
+
f"{config},{row['ckpt_step']},{views},{tokens},{row['token_acc_mean']},"
|
| 115 |
+
f"{row['exact_count']},{row['exact_ref_count']},\"{row['exact_ref_hits']}\"\n"
|
| 116 |
+
)
|
| 117 |
+
PY
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
configs=(
|
| 121 |
+
wrongfloor0p3
|
| 122 |
+
wrongfloor0p5
|
| 123 |
+
wrongfloor0p7
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
echo "[wrong-floor] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"
|
| 127 |
+
round_idx=0
|
| 128 |
+
while :; do
|
| 129 |
+
round_idx=$((round_idx + 1))
|
| 130 |
+
active=0
|
| 131 |
+
echo "[wrong-floor] round=${round_idx} $(date)" | tee -a "${DRIVER_LOG}"
|
| 132 |
+
for config in "${configs[@]}"; do
|
| 133 |
+
floor="${config#wrongfloor}"
|
| 134 |
+
floor="${floor//p/.}"
|
| 135 |
+
run_name="train8_wrongfloor_len${MAX_LEN}_${config}_${GROUP_STAMP}"
|
| 136 |
+
step_now="$(latest_step "${run_name}")"
|
| 137 |
+
if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
|
| 138 |
+
echo "[wrong-floor] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
|
| 139 |
+
continue
|
| 140 |
+
fi
|
| 141 |
+
active=1
|
| 142 |
+
target_step=$((step_now + STEP_CHUNK))
|
| 143 |
+
if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
|
| 144 |
+
target_step="${MAX_TOTAL_STEPS}"
|
| 145 |
+
fi
|
| 146 |
+
resume_path=""
|
| 147 |
+
if [[ -f "runs/${run_name}/latest.pt" ]]; then
|
| 148 |
+
resume_path="runs/${run_name}/latest.pt"
|
| 149 |
+
fi
|
| 150 |
+
echo "[wrong-floor] train config=${config} floor=${floor} from=${step_now} to=${target_step}" | tee -a "${DRIVER_LOG}"
|
| 151 |
+
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
|
| 152 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
|
| 153 |
+
MASTER_PORT="$(free_port)" \
|
| 154 |
+
OWT_CHUNK_CACHE_DIR="${cache}" \
|
| 155 |
+
OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
|
| 156 |
+
MAX_LEN="${MAX_LEN}" \
|
| 157 |
+
VOCAB_SIZE_OVERRIDE="${vocab_size}" \
|
| 158 |
+
D_MODEL="${D_MODEL:-192}" \
|
| 159 |
+
COND_DIM="${COND_DIM:-64}" \
|
| 160 |
+
N_LAYERS="${N_LAYERS:-3}" \
|
| 161 |
+
N_HEADS="${N_HEADS:-3}" \
|
| 162 |
+
DIM_FF="${DIM_FF:-768}" \
|
| 163 |
+
TOTAL_STEPS="${target_step}" \
|
| 164 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
|
| 165 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
|
| 166 |
+
NUM_WORKERS="${NUM_WORKERS:-0}" \
|
| 167 |
+
LOG_EVERY="${LOG_EVERY:-100}" \
|
| 168 |
+
SAVE_EVERY="${STEP_CHUNK}" \
|
| 169 |
+
LATEST_EVERY="${STEP_CHUNK}" \
|
| 170 |
+
WARMUP_STEPS="${WARMUP_STEPS:-10}" \
|
| 171 |
+
LEARNING_RATE="${LEARNING_RATE:-0.002}" \
|
| 172 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" \
|
| 173 |
+
MUON_IMPL="${MUON_IMPL:-legacy}" \
|
| 174 |
+
OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}" \
|
| 175 |
+
TARGET_LOSS=hard_ce \
|
| 176 |
+
MIN_MASK_RATIO=1.0 \
|
| 177 |
+
MAX_MASK_RATIO=1.0 \
|
| 178 |
+
MASK_MIXTURE_LOWK_PROB=0.0 \
|
| 179 |
+
MASK_MIXTURE_ALL_PROB=1.0 \
|
| 180 |
+
LOWK_CLEAN_TOKENS=0 \
|
| 181 |
+
CLEAN_STATE_MODE=onehot \
|
| 182 |
+
ROLLOUT_TRAIN_PROB=0.0 \
|
| 183 |
+
CATEGORICAL_WRONG_PROB_FLOOR="${floor}" \
|
| 184 |
+
RUN_NAME="${run_name}" \
|
| 185 |
+
RESUME_PATH="${resume_path}" \
|
| 186 |
+
bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
|
| 187 |
+
echo "[wrong-floor] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
|
| 188 |
+
eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"
|
| 189 |
+
done
|
| 190 |
+
if [[ "${active}" -eq 0 ]]; then
|
| 191 |
+
echo "[wrong-floor] all capped $(date)" | tee -a "${DRIVER_LOG}"
|
| 192 |
+
break
|
| 193 |
+
fi
|
| 194 |
+
done
|
LTA_openwebtext_dualt/scripts/watch_infer_owt_classic_fullvocab_len1024_lr2e4_gbs2048_latest_every1k_t1p45.sh
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 7 |
+
export TOKENIZERS_PARALLELISM=false
|
| 8 |
+
export PYTHONUNBUFFERED=1
|
| 9 |
+
|
| 10 |
+
# Watch the 16-GPU OWT classic full-vocab len1024/lr2e-4/GBS2048 run.
|
| 11 |
+
# The training command saves step_*.pt every 10k but latest.pt every 1k, so this
|
| 12 |
+
# watcher snapshots stable latest.pt at each new 1k step before running infer.
|
| 13 |
+
|
| 14 |
+
RUN_GLOB="${RUN_GLOB:-runs/lta_owt_classic_fullvocab_bert_c1024_len1024_lr2e4_gbs2048_2node8gpu_1m_save10k_*}"
|
| 15 |
+
RUN_DIR="${RUN_DIR:-}"
|
| 16 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
|
| 17 |
+
SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"
|
| 18 |
+
|
| 19 |
+
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
|
| 20 |
+
N_SAMPLES="${N_SAMPLES:-1024}"
|
| 21 |
+
STEPS="${STEPS:-128}"
|
| 22 |
+
CMAX="${CMAX:-1024}"
|
| 23 |
+
TEMP="${TEMP:-1.45}"
|
| 24 |
+
MAX_LEN="${MAX_LEN:-1024}"
|
| 25 |
+
DECODE_BATCH="${DECODE_BATCH:-1}"
|
| 26 |
+
SCORE_BATCH="${SCORE_BATCH:-1}"
|
| 27 |
+
SCORE_MAX_LENGTH="${SCORE_MAX_LENGTH:-1024}"
|
| 28 |
+
SLEEP_SECONDS="${SLEEP_SECONDS:-60}"
|
| 29 |
+
STEP_INTERVAL="${STEP_INTERVAL:-1000}"
|
| 30 |
+
DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
|
| 31 |
+
|
| 32 |
+
TEMP_TAG="${TEMP//./p}"
|
| 33 |
+
LOG_DIR="${LOG_DIR:-logs/owt_classic_fullvocab_len1024_lr2e4_gbs2048_infer_watch}"
|
| 34 |
+
OUT_ROOT="${OUT_ROOT:-docs/lta_samples/metrics_${DATE_TAG}/owt_classic_fullvocab_len1024_lr2e4_gbs2048_latest_every1k_normal_steps_state_t${TEMP_TAG}_c${CMAX}_n${N_SAMPLES}}"
|
| 35 |
+
|
| 36 |
+
mkdir -p "${LOG_DIR}" "${OUT_ROOT}"
|
| 37 |
+
|
| 38 |
+
find_run_dir() {
|
| 39 |
+
if [[ -n "${RUN_DIR}" ]]; then
|
| 40 |
+
if [[ -d "${RUN_DIR}" ]]; then
|
| 41 |
+
printf '%s\n' "${RUN_DIR}"
|
| 42 |
+
return 0
|
| 43 |
+
fi
|
| 44 |
+
return 1
|
| 45 |
+
fi
|
| 46 |
+
shopt -s nullglob
|
| 47 |
+
local matches=( ${RUN_GLOB} )
|
| 48 |
+
shopt -u nullglob
|
| 49 |
+
if (( ${#matches[@]} == 0 )); then
|
| 50 |
+
return 1
|
| 51 |
+
fi
|
| 52 |
+
ls -td "${matches[@]}" 2>/dev/null | head -1
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
wait_for_stable_file() {
|
| 56 |
+
local path="$1"
|
| 57 |
+
local stat_a stat_b
|
| 58 |
+
stat_a="$(stat -c '%s:%Y' "${path}" 2>/dev/null || echo missing)"
|
| 59 |
+
sleep 20
|
| 60 |
+
stat_b="$(stat -c '%s:%Y' "${path}" 2>/dev/null || echo changed)"
|
| 61 |
+
[[ "${stat_a}" == "${stat_b}" && "${stat_a}" != "missing" ]]
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
read_ckpt_step() {
|
| 65 |
+
local ckpt="$1"
|
| 66 |
+
python - "$ckpt" <<'PY'
|
| 67 |
+
import sys
|
| 68 |
+
import torch
|
| 69 |
+
ckpt = torch.load(sys.argv[1], map_location="cpu", weights_only=False)
|
| 70 |
+
step = ckpt.get("step")
|
| 71 |
+
if step is None:
|
| 72 |
+
raise SystemExit("checkpoint has no step")
|
| 73 |
+
print(int(step))
|
| 74 |
+
PY
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
echo "[watch-owt-len1024-lr2e4] run_glob=${RUN_GLOB}"
|
| 78 |
+
echo "[watch-owt-len1024-lr2e4] explicit_run_dir=${RUN_DIR:-<auto>}"
|
| 79 |
+
echo "[watch-owt-len1024-lr2e4] out_root=${OUT_ROOT}"
|
| 80 |
+
echo "[watch-owt-len1024-lr2e4] decode=normal_steps_sweep steps=${STEPS} cmax=${CMAX} temp=${TEMP} final_from=state n=${N_SAMPLES} max_len=${MAX_LEN}"
|
| 81 |
+
echo "[watch-owt-len1024-lr2e4] source=latest.pt snapshot_each=${STEP_INTERVAL} decode_batch=${DECODE_BATCH} score_batch=${SCORE_BATCH}"
|
| 82 |
+
|
| 83 |
+
while true; do
|
| 84 |
+
current_run_dir="$(find_run_dir || true)"
|
| 85 |
+
if [[ -z "${current_run_dir}" ]]; then
|
| 86 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) waiting for matching run: ${RUN_GLOB}"
|
| 87 |
+
sleep "${SLEEP_SECONDS}"
|
| 88 |
+
continue
|
| 89 |
+
fi
|
| 90 |
+
|
| 91 |
+
run_stem="$(basename "${current_run_dir}")"
|
| 92 |
+
latest_ckpt="${current_run_dir}/latest.pt"
|
| 93 |
+
out_base="${OUT_ROOT}/${run_stem}"
|
| 94 |
+
processed_file="${LOG_DIR}/processed_${run_stem}_steps${STEPS}_c${CMAX}_t${TEMP_TAG}_n${N_SAMPLES}.txt"
|
| 95 |
+
snapshot_dir="${current_run_dir}/latest_snapshots_1k"
|
| 96 |
+
mkdir -p "${out_base}" "${LOG_DIR}" "${snapshot_dir}"
|
| 97 |
+
touch "${processed_file}"
|
| 98 |
+
|
| 99 |
+
if [[ ! -f "${latest_ckpt}" ]]; then
|
| 100 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) run=${run_stem} no latest.pt yet"
|
| 101 |
+
sleep "${SLEEP_SECONDS}"
|
| 102 |
+
continue
|
| 103 |
+
fi
|
| 104 |
+
if ! wait_for_stable_file "${latest_ckpt}"; then
|
| 105 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) latest.pt not stable yet"
|
| 106 |
+
sleep "${SLEEP_SECONDS}"
|
| 107 |
+
continue
|
| 108 |
+
fi
|
| 109 |
+
|
| 110 |
+
step_num="$(read_ckpt_step "${latest_ckpt}")"
|
| 111 |
+
if (( step_num <= 0 || step_num % STEP_INTERVAL != 0 )); then
|
| 112 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) latest step=${step_num}; waiting for multiple of ${STEP_INTERVAL}"
|
| 113 |
+
sleep "${SLEEP_SECONDS}"
|
| 114 |
+
continue
|
| 115 |
+
fi
|
| 116 |
+
|
| 117 |
+
step="$(printf '%07d' "${step_num}")"
|
| 118 |
+
snapshot="${snapshot_dir}/step_${step}.pt"
|
| 119 |
+
processed_key="${current_run_dir}:step_${step}"
|
| 120 |
+
if grep -Fxq "${processed_key}" "${processed_file}"; then
|
| 121 |
+
sleep "${SLEEP_SECONDS}"
|
| 122 |
+
continue
|
| 123 |
+
fi
|
| 124 |
+
|
| 125 |
+
if [[ ! -f "${snapshot}" ]]; then
|
| 126 |
+
tmp_snapshot="${snapshot}.tmp.$$"
|
| 127 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) snapshot latest step_${step} -> ${snapshot}"
|
| 128 |
+
cp --reflink=auto "${latest_ckpt}" "${tmp_snapshot}" 2>/dev/null || cp "${latest_ckpt}" "${tmp_snapshot}"
|
| 129 |
+
mv "${tmp_snapshot}" "${snapshot}"
|
| 130 |
+
fi
|
| 131 |
+
|
| 132 |
+
out_dir="${out_base}/step_${step}"
|
| 133 |
+
log_file="${LOG_DIR}/infer_${run_stem}_step_${step}_t${TEMP_TAG}.log"
|
| 134 |
+
mkdir -p "${out_dir}"
|
| 135 |
+
|
| 136 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) infer ${snapshot} -> ${out_dir}" | tee -a "${log_file}"
|
| 137 |
+
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" python scripts/eval_owt_normal_steps_sweep_20260515.py \
|
| 138 |
+
--checkpoint "${snapshot}" \
|
| 139 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 140 |
+
--scorer "${SCORER}" \
|
| 141 |
+
--out_dir "${out_dir}" \
|
| 142 |
+
--steps_list "${STEPS}" \
|
| 143 |
+
--cmax_list "${CMAX}" \
|
| 144 |
+
--endpoint_temps "${TEMP}" \
|
| 145 |
+
--n_samples "${N_SAMPLES}" \
|
| 146 |
+
--max_len "${MAX_LEN}" \
|
| 147 |
+
--decode_batch "${DECODE_BATCH}" \
|
| 148 |
+
--score_batch "${SCORE_BATCH}" \
|
| 149 |
+
--score_max_length "${SCORE_MAX_LENGTH}" \
|
| 150 |
+
--detokenizer none \
|
| 151 |
+
--seed 20260521 \
|
| 152 |
+
--save_samples 16 \
|
| 153 |
+
2>&1 | tee -a "${log_file}"
|
| 154 |
+
|
| 155 |
+
echo "${processed_key}" >> "${processed_file}"
|
| 156 |
+
echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) done step_${step}" | tee -a "${log_file}"
|
| 157 |
+
sleep "${SLEEP_SECONDS}"
|
| 158 |
+
done
|