JinghuiLuAstronaut commited on
Commit
6bc5b2c
·
verified ·
1 Parent(s): 50fda35

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. LTA_openwebtext_dualt/logs/lta_lm1b_duo_aligned_dirichlet_true_dualtline_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m.log +0 -0
  2. LTA_openwebtext_dualt/logs/owt_fully_marginal_categorical_probe_step118k_n8.log +40 -0
  3. LTA_openwebtext_dualt/logs/softendpoint_mn_pilot_4gpu/train8_ctx1024_randk_p50_path3_rand2_3_unif0_0p25_ctx1024_uniformt_temp1_randk_20260518_010217.log +404 -0
  4. LTA_openwebtext_dualt/logs/softendpoint_mn_pilot_4gpu/train8_n1024_hard_ce_bridge_20260517_train8_overfit.log +316 -0
  5. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/hf_xet-1.5.0.dist-info/INSTALLER +1 -0
  6. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/hf_xet-1.5.0.dist-info/METADATA +87 -0
  7. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/hf_xet-1.5.0.dist-info/REQUESTED +0 -0
  8. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/typer/core.py +820 -0
  9. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/typer/py.typed +0 -0
  10. LTA_openwebtext_dualt/scripts/dirichlet_support_decode_probe.py +347 -0
  11. LTA_openwebtext_dualt/scripts/infer_softkl_decode_probe.py.bak_lognsr_gumbel_20260519 +960 -0
  12. LTA_openwebtext_dualt/scripts/launch_lta_lm1b_dualtline_cmax16_8gpu_duo_small_1m.sh +139 -0
  13. LTA_openwebtext_dualt/scripts/launch_lta_owt_gamma2_8gpu.sh +134 -0
  14. LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_adaln_adamw_wd0p1_rollin_p50_randk0_3_2node.sh +186 -0
  15. LTA_openwebtext_dualt/scripts/qwen_transformers_openai_server.py +199 -0
  16. LTA_openwebtext_dualt/scripts/run_lta_owt_bert_absrope_time4_dirichlet_len1024_Cv_to_2v_8gpu_1m_mask1_sameT_save10k.sh +70 -0
  17. LTA_openwebtext_dualt/scripts/run_train8_rollin_focused_pilots_4gpu.sh +272 -0
  18. LTA_openwebtext_dualt/scripts/run_train8_selected_long20k_4gpu.sh +228 -0
  19. LTA_openwebtext_dualt/scripts/score_lta_decode_strategy.py +144 -0
  20. LTA_openwebtext_dualt/scripts/watch_infer_lm1b_classic_c1024_every1k_t1p45.sh +92 -0
LTA_openwebtext_dualt/logs/lta_lm1b_duo_aligned_dirichlet_true_dualtline_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m.log ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/logs/owt_fully_marginal_categorical_probe_step118k_n8.log ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [forbid_endpoint_ids] n=352 first=[94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]
2
+ [decode] steps32_c128_mtpre_t1p0_tpow1p0_noise0_state_marginal_categorical
3
+ [summary] {"name": "steps32_c128_mtpre_t1p0_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.0, "temp_end": 1.0, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.138821908505566, "distinct_1": 0.045166015625, "distinct_2": 0.16300097751710654, "top_token_mass": 0.4710693359375, "tokens_scored": 0, "readability_score": 2.701582987844995, "mean_chars": 1753.0, "replacement_chars": 0.0}
4
+ [decode] steps32_c128_mtpost_t1p0_tpow1p0_noise0_state_marginal_categorical
5
+ [summary] {"name": "steps32_c128_mtpost_t1p0_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.0, "temp_end": 1.0, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 1.8338167258452847, "distinct_1": 0.0372314453125, "distinct_2": 0.11742424242424243, "top_token_mass": 0.5498046875, "tokens_scored": 0, "readability_score": 2.144895081574039, "mean_chars": 1478.875, "replacement_chars": 0.0}
6
+ [decode] steps32_c128_mtpre_t1p0_tpow1p0_noise0_endpoint_marginal_categorical
7
+ [summary] {"name": "steps32_c128_mtpre_t1p0_tpow1p0_noise0_endpoint_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.0, "temp_end": 1.0, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "endpoint", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.138821908505566, "distinct_1": 0.045166015625, "distinct_2": 0.16300097751710654, "top_token_mass": 0.4710693359375, "tokens_scored": 0, "readability_score": 2.701582987844995, "mean_chars": 1753.0, "replacement_chars": 0.0}
8
+ [decode] steps32_c128_mtpost_t1p0_tpow1p0_noise0_endpoint_marginal_categorical
9
+ [summary] {"name": "steps32_c128_mtpost_t1p0_tpow1p0_noise0_endpoint_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.0, "temp_end": 1.0, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "endpoint", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 1.8338167258452847, "distinct_1": 0.0372314453125, "distinct_2": 0.11742424242424243, "top_token_mass": 0.5498046875, "tokens_scored": 0, "readability_score": 2.144895081574039, "mean_chars": 1478.875, "replacement_chars": 0.0}
10
+ [decode] steps32_c128_mtpre_t1p0_tpow1p0_noise0_blend_marginal_categorical
11
+ [summary] {"name": "steps32_c128_mtpre_t1p0_tpow1p0_noise0_blend_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.0, "temp_end": 1.0, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "blend", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.138821908505566, "distinct_1": 0.045166015625, "distinct_2": 0.16300097751710654, "top_token_mass": 0.4710693359375, "tokens_scored": 0, "readability_score": 2.701582987844995, "mean_chars": 1753.0, "replacement_chars": 0.0}
12
+ [decode] steps32_c128_mtpost_t1p0_tpow1p0_noise0_blend_marginal_categorical
13
+ [summary] {"name": "steps32_c128_mtpost_t1p0_tpow1p0_noise0_blend_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.0, "temp_end": 1.0, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "blend", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 1.8338167258452847, "distinct_1": 0.0372314453125, "distinct_2": 0.11742424242424243, "top_token_mass": 0.5498046875, "tokens_scored": 0, "readability_score": 2.144895081574039, "mean_chars": 1478.875, "replacement_chars": 0.0}
14
+ [decode] steps32_c128_mtpre_t1p1_tpow1p0_noise0_state_marginal_categorical
15
+ [summary] {"name": "steps32_c128_mtpre_t1p1_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.1, "temp_end": 1.1, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.9428859927624273, "distinct_1": 0.0694580078125, "distinct_2": 0.30584066471163246, "top_token_mass": 0.3778076171875, "tokens_scored": 0, "readability_score": 3.9354258205976476, "mean_chars": 2241.25, "replacement_chars": 0.0}
16
+ [decode] steps32_c128_mtpost_t1p1_tpow1p0_noise0_state_marginal_categorical
17
+ [summary] {"name": "steps32_c128_mtpost_t1p1_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.1, "temp_end": 1.1, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.4547148052858425, "distinct_1": 0.06396484375, "distinct_2": 0.2371700879765396, "top_token_mass": 0.4598388671875, "tokens_scored": 0, "readability_score": 3.661850259272684, "mean_chars": 1802.75, "replacement_chars": 0.0}
18
+ [decode] steps32_c128_mtpre_t1p1_tpow1p0_noise0_endpoint_marginal_categorical
19
+ [summary] {"name": "steps32_c128_mtpre_t1p1_tpow1p0_noise0_endpoint_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.1, "temp_end": 1.1, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "endpoint", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.9428859927624273, "distinct_1": 0.0694580078125, "distinct_2": 0.30584066471163246, "top_token_mass": 0.3778076171875, "tokens_scored": 0, "readability_score": 3.9354258205976476, "mean_chars": 2241.25, "replacement_chars": 0.0}
20
+ [decode] steps32_c128_mtpost_t1p1_tpow1p0_noise0_endpoint_marginal_categorical
21
+ [summary] {"name": "steps32_c128_mtpost_t1p1_tpow1p0_noise0_endpoint_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.1, "temp_end": 1.1, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "endpoint", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.4547148052858425, "distinct_1": 0.06396484375, "distinct_2": 0.2371700879765396, "top_token_mass": 0.4598388671875, "tokens_scored": 0, "readability_score": 3.661850259272684, "mean_chars": 1802.75, "replacement_chars": 0.0}
22
+ [decode] steps32_c128_mtpre_t1p1_tpow1p0_noise0_blend_marginal_categorical
23
+ [summary] {"name": "steps32_c128_mtpre_t1p1_tpow1p0_noise0_blend_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.1, "temp_end": 1.1, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "blend", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.9428859927624273, "distinct_1": 0.0694580078125, "distinct_2": 0.30584066471163246, "top_token_mass": 0.3778076171875, "tokens_scored": 0, "readability_score": 3.9354258205976476, "mean_chars": 2241.25, "replacement_chars": 0.0}
24
+ [decode] steps32_c128_mtpost_t1p1_tpow1p0_noise0_blend_marginal_categorical
25
+ [summary] {"name": "steps32_c128_mtpost_t1p1_tpow1p0_noise0_blend_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.1, "temp_end": 1.1, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "blend", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 2.4547148052858425, "distinct_1": 0.06396484375, "distinct_2": 0.2371700879765396, "top_token_mass": 0.4598388671875, "tokens_scored": 0, "readability_score": 3.661850259272684, "mean_chars": 1802.75, "replacement_chars": 0.0}
26
+ [decode] steps32_c128_mtpre_t1p2_tpow1p0_noise0_state_marginal_categorical
27
+ [summary] {"name": "steps32_c128_mtpre_t1p2_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.2, "temp_end": 1.2, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 3.9514859184602247, "distinct_1": 0.102294921875, "distinct_2": 0.49108015640273706, "top_token_mass": 0.2208251953125, "tokens_scored": 0, "readability_score": 4.272469777249244, "mean_chars": 3242.625, "replacement_chars": 0.0}
28
+ [decode] steps32_c128_mtpost_t1p2_tpow1p0_noise0_state_marginal_categorical
29
+ [summary] {"name": "steps32_c128_mtpost_t1p2_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.2, "temp_end": 1.2, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 3.735299917945654, "distinct_1": 0.1103515625, "distinct_2": 0.4671309872922776, "top_token_mass": 0.20361328125, "tokens_scored": 0, "readability_score": 4.142991364470012, "mean_chars": 2971.625, "replacement_chars": 0.0}
30
+ [decode] steps32_c128_mtpre_t1p2_tpow1p0_noise0_endpoint_marginal_categorical
31
+ [summary] {"name": "steps32_c128_mtpre_t1p2_tpow1p0_noise0_endpoint_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.2, "temp_end": 1.2, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "endpoint", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 3.9514859184602247, "distinct_1": 0.102294921875, "distinct_2": 0.49108015640273706, "top_token_mass": 0.2208251953125, "tokens_scored": 0, "readability_score": 4.272469777249244, "mean_chars": 3242.625, "replacement_chars": 0.0}
32
+ [decode] steps32_c128_mtpost_t1p2_tpow1p0_noise0_endpoint_marginal_categorical
33
+ [summary] {"name": "steps32_c128_mtpost_t1p2_tpow1p0_noise0_endpoint_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.2, "temp_end": 1.2, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "endpoint", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 3.735299917945654, "distinct_1": 0.1103515625, "distinct_2": 0.4671309872922776, "top_token_mass": 0.20361328125, "tokens_scored": 0, "readability_score": 4.142991364470012, "mean_chars": 2971.625, "replacement_chars": 0.0}
34
+ [decode] steps32_c128_mtpre_t1p2_tpow1p0_noise0_blend_marginal_categorical
35
+ [summary] {"name": "steps32_c128_mtpre_t1p2_tpow1p0_noise0_blend_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.2, "temp_end": 1.2, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "blend", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 3.9514859184602247, "distinct_1": 0.102294921875, "distinct_2": 0.49108015640273706, "top_token_mass": 0.2208251953125, "tokens_scored": 0, "readability_score": 4.272469777249244, "mean_chars": 3242.625, "replacement_chars": 0.0}
36
+ [decode] steps32_c128_mtpost_t1p2_tpow1p0_noise0_blend_marginal_categorical
37
+ [summary] {"name": "steps32_c128_mtpost_t1p2_tpow1p0_noise0_blend_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.2, "temp_end": 1.2, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "blend", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "post", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 3.735299917945654, "distinct_1": 0.1103515625, "distinct_2": 0.4671309872922776, "top_token_mass": 0.20361328125, "tokens_scored": 0, "readability_score": 4.142991364470012, "mean_chars": 2971.625, "replacement_chars": 0.0}
38
+ [decode] steps32_c128_mtpre_t1p3_tpow1p0_noise0_state_marginal_categorical
39
+ [summary] {"name": "steps32_c128_mtpre_t1p3_tpow1p0_noise0_state_marginal_categorical", "step": 118000, "n_samples": 8, "steps": 32, "concentration_max": 128.0, "temp_start": 1.3, "temp_end": 1.3, "temp_schedule": "const", "t_power": 1.0, "eta0": 0.0, "eta_schedule": "none", "noise_conc": 1.0, "final_from": "state", "final_decode": "argmax", "final_temp": 1.0, "final_top_k": 0, "final_uncertain_threshold": 0.85, "update_rule": "marginal_categorical", "model_t_mode": "pre", "lock_bos": true, "lock_final_eos": false, "detok_genppl": NaN, "sample_entropy": 4.465642952565066, "distinct_1": 0.117919921875, "distinct_2": 0.5981182795698925, "top_token_mass": 0.0947265625, "tokens_scored": 0, "readability_score": 4.276690649955215, "mean_chars": 4014.75, "replacement_chars": 0.0}
40
+ [decode] steps32_c128_mtpost_t1p3_tpow1p0_noise0_state_marginal_categorical
LTA_openwebtext_dualt/logs/softendpoint_mn_pilot_4gpu/train8_ctx1024_randk_p50_path3_rand2_3_unif0_0p25_ctx1024_uniformt_temp1_randk_20260518_010217.log ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ NCCL version 2.25.1+cuda12.8
2
+ {
3
+ "device": "cuda:0",
4
+ "rank": 0,
5
+ "world_size": 4,
6
+ "samples": "owt_cached_chunks:8",
7
+ "vocab_size": 2664,
8
+ "tokenizer_vocab_size": 50257,
9
+ "save_dir": "runs/train8_ctx1024_randk_p50_path3_rand2_3_unif0_0p25_ctx1024_uniformt_temp1_randk_20260518_010217",
10
+ "batch_size": 128,
11
+ "grad_accum": 1,
12
+ "effective_batch_size": 512,
13
+ "global_batch_size": 512,
14
+ "lr_schedule": "constant_warmup",
15
+ "optimizer": "muon",
16
+ "epochs": 0.0,
17
+ "steps_per_epoch": 1,
18
+ "total_steps": 1000,
19
+ "warmup_steps": 10,
20
+ "warmup_epochs": -1.0,
21
+ "min_lr": 0.0,
22
+ "weight_decay": 0.1,
23
+ "output_weight_decay": -1.0,
24
+ "adamw_param_groups": "nanogpt",
25
+ "adam_beta1": 0.9,
26
+ "adam_beta2": 0.95,
27
+ "adam_eps": 1e-08,
28
+ "muon_impl": "legacy",
29
+ "muon_momentum": 0.95,
30
+ "muon_ns_steps": 5,
31
+ "muon_update_scale": 1.0,
32
+ "muon_nesterov": false,
33
+ "muon_width_scale": false,
34
+ "muon_grouping": "legacy_dim_ge_2",
35
+ "muon_param_count": 2616320,
36
+ "muon_adam_param_count": 8192,
37
+ "muon_param_names": [
38
+ "vocab_embed.embedding",
39
+ "sigma_map.net.0.weight",
40
+ "sigma_map.net.2.weight",
41
+ "blocks.0.attn_qkv.weight",
42
+ "blocks.0.attn_out.weight",
43
+ "blocks.0.mlp.0.weight",
44
+ "blocks.0.mlp.2.weight",
45
+ "blocks.0.adaLN_modulation.weight",
46
+ "blocks.1.attn_qkv.weight",
47
+ "blocks.1.attn_out.weight",
48
+ "blocks.1.mlp.0.weight",
49
+ "blocks.1.mlp.2.weight",
50
+ "blocks.1.adaLN_modulation.weight",
51
+ "blocks.2.attn_qkv.weight",
52
+ "blocks.2.attn_out.weight",
53
+ "blocks.2.mlp.0.weight",
54
+ "blocks.2.mlp.2.weight",
55
+ "blocks.2.adaLN_modulation.weight",
56
+ "output_layer.linear.weight",
57
+ "output_layer.adaLN_modulation.weight"
58
+ ],
59
+ "muon_adam_param_names": [
60
+ "sigma_map.net.0.bias",
61
+ "sigma_map.net.2.bias",
62
+ "blocks.0.norm1.weight",
63
+ "blocks.0.norm2.weight",
64
+ "blocks.0.mlp.0.bias",
65
+ "blocks.0.mlp.2.bias",
66
+ "blocks.0.adaLN_modulation.bias",
67
+ "blocks.1.norm1.weight",
68
+ "blocks.1.norm2.weight",
69
+ "blocks.1.mlp.0.bias",
70
+ "blocks.1.mlp.2.bias",
71
+ "blocks.1.adaLN_modulation.bias",
72
+ "blocks.2.norm1.weight",
73
+ "blocks.2.norm2.weight",
74
+ "blocks.2.mlp.0.bias",
75
+ "blocks.2.mlp.2.bias",
76
+ "blocks.2.adaLN_modulation.bias",
77
+ "output_layer.norm_final.weight",
78
+ "output_layer.adaLN_modulation.bias"
79
+ ],
80
+ "muon_effective_nesterov": false,
81
+ "muon_effective_width_scale": false,
82
+ "muon_effective_weight_decay": 0.1,
83
+ "muon_adam_fallback_nesterov": false,
84
+ "muon_adam_fallback_weight_decay": 0.1,
85
+ "ema_decay": 0.9999,
86
+ "ema_start_step": 0,
87
+ "model_type": "ddit",
88
+ "ddit_mlp_type": "gelu",
89
+ "elf_num_time_tokens": 4,
90
+ "elf_num_model_mode_tokens": 0,
91
+ "qk_norm": true,
92
+ "output_bias": false,
93
+ "output_init_std": -1.0,
94
+ "norm_type": "rmsnorm",
95
+ "target_loss": "hard_ce",
96
+ "linear_soft_target_power": 1.0,
97
+ "linear_soft_target_min_conf": 0.0,
98
+ "linear_soft_target_max_conf": 1.0,
99
+ "t_sampling_mode": "uniform",
100
+ "t_sampling_power": 1.0,
101
+ "t_sampling_eps": 0.0001,
102
+ "t_sampling_logit_mean": -1.5,
103
+ "t_sampling_logit_std": 0.8,
104
+ "dual_t": true,
105
+ "corrupt_t_mode": "same",
106
+ "corrupt_min_t": 0.0,
107
+ "corrupt_max_t": 1.0,
108
+ "prefix_block_prob": 0.0,
109
+ "prefix_block_len": 128,
110
+ "mask_ratio_floor_schedule": "none",
111
+ "dirichlet_endpoint_mode": "categorical_dual_t",
112
+ "dirichlet_semantic_t_mode": "same",
113
+ "dirichlet_semantic_t_value": 0.0,
114
+ "dirichlet_semantic_t_curve": "linear",
115
+ "dirichlet_semantic_t_power": 1.0,
116
+ "endpoint_sequence_random_prob_alpha": 0.0,
117
+ "categorical_wrong_from_full_vocab": true,
118
+ "categorical_wrong_from_batch_valid_tokens": false,
119
+ "categorical_wrong_basin_token_ids": "",
120
+ "categorical_wrong_basin_prob": 0.0,
121
+ "categorical_wrong_unigram_prob": 0.0,
122
+ "categorical_wrong_uniform_prob": 0.0,
123
+ "categorical_wrong_prob_floor": 0.0,
124
+ "categorical_wrong_corpus_unigram_path": "",
125
+ "categorical_wrong_corpus_unigram_alpha": 1.0,
126
+ "categorical_wrong_basin_shared_prob": 0.0,
127
+ "categorical_wrong_unigram_shared_prob": 0.0,
128
+ "mask_mixture_original_prob": 0.0,
129
+ "mask_mixture_lowk_prob": 0.0,
130
+ "mask_mixture_lowcorrupt_prob": 0.0,
131
+ "mask_mixture_block_prob": 0.0,
132
+ "mask_mixture_all_prob": 1.0,
133
+ "mask_mixture_lowk_clean_tokens": "0",
134
+ "mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
135
+ "mask_mixture_block_tokens": "64,128",
136
+ "simplex_bridge_sampler": "dirichlet",
137
+ "logistic_normal_sigma_min": 0.1,
138
+ "logistic_normal_sigma_max": 1.0,
139
+ "logistic_normal_tau_min": 1.0,
140
+ "logistic_normal_tau_max": 1.0,
141
+ "torch_compile": false,
142
+ "compile_mode": "max-autotune",
143
+ "state_format": "prob",
144
+ "meanflow_weight": 0.0,
145
+ "rollout_train_prob": 0.5,
146
+ "rollout_train_steps": 3,
147
+ "rollout_train_steps_min": 2,
148
+ "rollout_train_infer_steps": 1,
149
+ "rollout_train_time_mode": "sampled_path",
150
+ "rollout_train_s_dist": "uniform",
151
+ "rollout_train_s_min_frac": 0.0,
152
+ "rollout_train_s_max_frac": 0.25,
153
+ "rollout_train_s_beta_alpha": 2.0,
154
+ "rollout_train_s_beta_beta": 6.0,
155
+ "rollout_train_temp": 1.0,
156
+ "rollout_train_max_gamma": 1.0,
157
+ "rollout_train_corrupt_only": true,
158
+ "rollout_train_samplewise": true,
159
+ "rollout_train_compute_always": false,
160
+ "rollout_train_sync_t": true,
161
+ "bridge_noise_init": "logistic_normal",
162
+ "noise_sigma": -1.0,
163
+ "allow_tf32": true,
164
+ "activation_checkpointing": false,
165
+ "activation_checkpoint_interval": 1,
166
+ "activation_checkpoint_scope": "block",
167
+ "ddp_static_graph": false,
168
+ "ddp_gradient_as_bucket_view": true,
169
+ "blocking_data_transfer": false,
170
+ "dataloader_prefetch_factor": 4,
171
+ "full_train_stats": false,
172
+ "tokenized_hf": false,
173
+ "tokenized_pad_token": "pad",
174
+ "elf_conditional_hf": false,
175
+ "record_pad_truncate": false,
176
+ "record_add_eos": false,
177
+ "record_add_special_tokens": false,
178
+ "record_pad_token": "pad",
179
+ "record_shuffle_buffer": 10000,
180
+ "wrap": true,
181
+ "wrap_mode": "stream",
182
+ "wrap_record_buffer_size": 200,
183
+ "owt_cached_chunks": true,
184
+ "owt_chunk_cache_dir": "/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train8_compact_overfit",
185
+ "owt_chunk_cache_rebuild": false,
186
+ "owt_chunk_cache_write_batch": 4096,
187
+ "owt_exact_repeat_per_chunk": 64,
188
+ "online_chunk_shuffle": false,
189
+ "online_chunk_shuffle_buffer": 10000,
190
+ "openwebtext_split": "train_minus_100k",
191
+ "detokenizer": "auto",
192
+ "resolved_detokenizer": null,
193
+ "num_workers": 0,
194
+ "latest_every": 1000,
195
+ "resume_path": ""
196
+ }
197
+ step=100 epoch=100/1000 epoch_step=1/1 micro_steps=100 elapsed=22.1s lr=2.000000e-03 loss=7.4577 loss_recon=7.4577 loss_meanflow=0.0000 mean_model_t=0.4971 mean_corrupt_t=0.4971 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.5052 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.3139 corrupt_frac=1.0000 acc_corrupt=0.3139 loss_corrupt=7.4577 wrong_frac=0.5028 init_acc_corrupt=0.4627 acc_corrupt_t_0p0_0p2=0.0440 corrupt_frac_t_0p0_0p2=0.2071 acc_corrupt_t_0p2_0p4=0.1523 corrupt_frac_t_0p2_0p4=0.1988 acc_corrupt_t_0p4_0p6=0.3124 corrupt_frac_t_0p4_0p6=0.1976 acc_corrupt_t_0p6_0p8=0.4627 corrupt_frac_t_0p6_0p8=0.1951 acc_corrupt_t_0p8_1p0=0.6082 corrupt_frac_t_0p8_1p0=0.2015 out_w_norm=1.0709 out_g_norm=0.9530 loss_all=6.9234 init_gold_top10=0.4680 init_gold_top100=0.5774 rollout_applied_pos_frac=0.4688 init_acc_rollout_applied=0.4067 init_acc_rollout_kept=0.4381 logit_acc_rollout_applied=0.2671 logit_acc_rollout_kept=0.2807
198
+ step=200 epoch=200/1000 epoch_step=1/1 micro_steps=200 elapsed=21.1s lr=2.000000e-03 loss=6.0057 loss_recon=6.0057 loss_meanflow=0.0000 mean_model_t=0.5015 mean_corrupt_t=0.5015 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.5026 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.3125 corrupt_frac=1.0000 acc_corrupt=0.3125 loss_corrupt=6.0057 wrong_frac=0.4984 init_acc_corrupt=0.4680 acc_corrupt_t_0p0_0p2=0.0494 corrupt_frac_t_0p0_0p2=0.2005 acc_corrupt_t_0p2_0p4=0.1494 corrupt_frac_t_0p2_0p4=0.1979 acc_corrupt_t_0p4_0p6=0.3107 corrupt_frac_t_0p4_0p6=0.1994 acc_corrupt_t_0p6_0p8=0.4530 corrupt_frac_t_0p6_0p8=0.2007 acc_corrupt_t_0p8_1p0=0.5961 corrupt_frac_t_0p8_1p0=0.2015 out_w_norm=3.4215 out_g_norm=1.2538 loss_all=5.2425 init_gold_top10=0.5104 init_gold_top100=0.6209 rollout_applied_pos_frac=0.5391 init_acc_rollout_applied=0.4934 init_acc_rollout_kept=0.4670 logit_acc_rollout_applied=0.3498 logit_acc_rollout_kept=0.3434
199
+ step=300 epoch=300/1000 epoch_step=1/1 micro_steps=300 elapsed=21.2s lr=2.000000e-03 loss=4.9295 loss_recon=4.9295 loss_meanflow=0.0000 mean_model_t=0.5015 mean_corrupt_t=0.5015 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.5045 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.3538 corrupt_frac=1.0000 acc_corrupt=0.3538 loss_corrupt=4.9295 wrong_frac=0.4985 init_acc_corrupt=0.4693 acc_corrupt_t_0p0_0p2=0.0519 corrupt_frac_t_0p0_0p2=0.1947 acc_corrupt_t_0p2_0p4=0.1777 corrupt_frac_t_0p2_0p4=0.2030 acc_corrupt_t_0p4_0p6=0.3483 corrupt_frac_t_0p4_0p6=0.1991 acc_corrupt_t_0p6_0p8=0.5083 corrupt_frac_t_0p6_0p8=0.2045 acc_corrupt_t_0p8_1p0=0.6758 corrupt_frac_t_0p8_1p0=0.1988 out_w_norm=5.5236 out_g_norm=0.5180 loss_all=4.5734 init_gold_top10=0.5207 init_gold_top100=0.6348 rollout_applied_pos_frac=0.5391 init_acc_rollout_applied=0.4713 init_acc_rollout_kept=0.4756 logit_acc_rollout_applied=0.3741 logit_acc_rollout_kept=0.3869
200
+ step=400 epoch=400/1000 epoch_step=1/1 micro_steps=400 elapsed=21.1s lr=2.000000e-03 loss=4.2910 loss_recon=4.2910 loss_meanflow=0.0000 mean_model_t=0.4986 mean_corrupt_t=0.4986 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.4976 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.4064 corrupt_frac=1.0000 acc_corrupt=0.4064 loss_corrupt=4.2910 wrong_frac=0.5016 init_acc_corrupt=0.4661 acc_corrupt_t_0p0_0p2=0.0554 corrupt_frac_t_0p0_0p2=0.2005 acc_corrupt_t_0p2_0p4=0.2003 corrupt_frac_t_0p2_0p4=0.2001 acc_corrupt_t_0p4_0p6=0.4018 corrupt_frac_t_0p4_0p6=0.2023 acc_corrupt_t_0p6_0p8=0.5897 corrupt_frac_t_0p6_0p8=0.1992 acc_corrupt_t_0p8_1p0=0.7905 corrupt_frac_t_0p8_1p0=0.1979 out_w_norm=7.1059 out_g_norm=0.2729 loss_all=4.1041 init_gold_top10=0.4976 init_gold_top100=0.6342 rollout_applied_pos_frac=0.5469 init_acc_rollout_applied=0.4729 init_acc_rollout_kept=0.4091 logit_acc_rollout_applied=0.4558 logit_acc_rollout_kept=0.4033
201
+ step=500 epoch=500/1000 epoch_step=1/1 micro_steps=500 elapsed=21.1s lr=2.000000e-03 loss=3.6585 loss_recon=3.6585 loss_meanflow=0.0000 mean_model_t=0.4977 mean_corrupt_t=0.4977 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.4981 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.4771 corrupt_frac=1.0000 acc_corrupt=0.4771 loss_corrupt=3.6585 wrong_frac=0.5023 init_acc_corrupt=0.4655 acc_corrupt_t_0p0_0p2=0.0565 corrupt_frac_t_0p0_0p2=0.2059 acc_corrupt_t_0p2_0p4=0.2322 corrupt_frac_t_0p2_0p4=0.1985 acc_corrupt_t_0p4_0p6=0.4947 corrupt_frac_t_0p4_0p6=0.1969 acc_corrupt_t_0p6_0p8=0.7063 corrupt_frac_t_0p6_0p8=0.1927 acc_corrupt_t_0p8_1p0=0.9024 corrupt_frac_t_0p8_1p0=0.2059 out_w_norm=8.4366 out_g_norm=0.2530 loss_all=3.4213 init_gold_top10=0.5026 init_gold_top100=0.6556 rollout_applied_pos_frac=0.5234 init_acc_rollout_applied=0.4297 init_acc_rollout_kept=0.4792 logit_acc_rollout_applied=0.4470 logit_acc_rollout_kept=0.5017
202
+ step=600 epoch=600/1000 epoch_step=1/1 micro_steps=600 elapsed=21.2s lr=2.000000e-03 loss=3.1426 loss_recon=3.1426 loss_meanflow=0.0000 mean_model_t=0.5012 mean_corrupt_t=0.5012 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.5030 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.4931 corrupt_frac=1.0000 acc_corrupt=0.4931 loss_corrupt=3.1426 wrong_frac=0.4987 init_acc_corrupt=0.4696 acc_corrupt_t_0p0_0p2=0.0591 corrupt_frac_t_0p0_0p2=0.1972 acc_corrupt_t_0p2_0p4=0.2633 corrupt_frac_t_0p2_0p4=0.2024 acc_corrupt_t_0p4_0p6=0.5212 corrupt_frac_t_0p4_0p6=0.1990 acc_corrupt_t_0p6_0p8=0.7135 corrupt_frac_t_0p6_0p8=0.2017 acc_corrupt_t_0p8_1p0=0.9043 corrupt_frac_t_0p8_1p0=0.1997 out_w_norm=9.7029 out_g_norm=0.2761 loss_all=2.5311 init_gold_top10=0.5979 init_gold_top100=0.7268 rollout_applied_pos_frac=0.4922 init_acc_rollout_applied=0.5201 init_acc_rollout_kept=0.5812 logit_acc_rollout_applied=0.5417 logit_acc_rollout_kept=0.6050
203
+ step=700 epoch=700/1000 epoch_step=1/1 micro_steps=700 elapsed=21.0s lr=2.000000e-03 loss=2.8141 loss_recon=2.8141 loss_meanflow=0.0000 mean_model_t=0.5008 mean_corrupt_t=0.5008 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.4924 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.4998 corrupt_frac=1.0000 acc_corrupt=0.4998 loss_corrupt=2.8141 wrong_frac=0.4994 init_acc_corrupt=0.4693 acc_corrupt_t_0p0_0p2=0.0609 corrupt_frac_t_0p0_0p2=0.2030 acc_corrupt_t_0p2_0p4=0.2757 corrupt_frac_t_0p2_0p4=0.1915 acc_corrupt_t_0p4_0p6=0.5280 corrupt_frac_t_0p4_0p6=0.2023 acc_corrupt_t_0p6_0p8=0.7194 corrupt_frac_t_0p6_0p8=0.2001 acc_corrupt_t_0p8_1p0=0.9050 corrupt_frac_t_0p8_1p0=0.2032 out_w_norm=10.7186 out_g_norm=0.3014 loss_all=2.6805 init_gold_top10=0.5397 init_gold_top100=0.7218 rollout_applied_pos_frac=0.5000 init_acc_rollout_applied=0.4357 init_acc_rollout_kept=0.5183 logit_acc_rollout_applied=0.4688 logit_acc_rollout_kept=0.5495
204
+ step=800 epoch=800/1000 epoch_step=1/1 micro_steps=800 elapsed=21.1s lr=2.000000e-03 loss=2.3910 loss_recon=2.3910 loss_meanflow=0.0000 mean_model_t=0.4962 mean_corrupt_t=0.4962 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.4930 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.5223 corrupt_frac=1.0000 acc_corrupt=0.5223 loss_corrupt=2.3910 wrong_frac=0.5037 init_acc_corrupt=0.4653 acc_corrupt_t_0p0_0p2=0.0621 corrupt_frac_t_0p0_0p2=0.2026 acc_corrupt_t_0p2_0p4=0.3034 corrupt_frac_t_0p2_0p4=0.2020 acc_corrupt_t_0p4_0p6=0.5729 corrupt_frac_t_0p4_0p6=0.2005 acc_corrupt_t_0p6_0p8=0.7638 corrupt_frac_t_0p6_0p8=0.1965 acc_corrupt_t_0p8_1p0=0.9245 corrupt_frac_t_0p8_1p0=0.1984 out_w_norm=11.2253 out_g_norm=0.3916 loss_all=2.2693 init_gold_top10=0.5604 init_gold_top100=0.7066 rollout_applied_pos_frac=0.4609 init_acc_rollout_applied=0.5007 init_acc_rollout_kept=0.4082 logit_acc_rollout_applied=0.5893 logit_acc_rollout_kept=0.4892
205
+ step=900 epoch=900/1000 epoch_step=1/1 micro_steps=900 elapsed=21.2s lr=2.000000e-03 loss=1.8221 loss_recon=1.8221 loss_meanflow=0.0000 mean_model_t=0.5040 mean_corrupt_t=0.5040 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.5050 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.6121 corrupt_frac=1.0000 acc_corrupt=0.6121 loss_corrupt=1.8221 wrong_frac=0.4960 init_acc_corrupt=0.4781 acc_corrupt_t_0p0_0p2=0.0648 corrupt_frac_t_0p0_0p2=0.1931 acc_corrupt_t_0p2_0p4=0.3866 corrupt_frac_t_0p2_0p4=0.2013 acc_corrupt_t_0p4_0p6=0.7303 corrupt_frac_t_0p4_0p6=0.2019 acc_corrupt_t_0p6_0p8=0.8832 corrupt_frac_t_0p6_0p8=0.1981 acc_corrupt_t_0p8_1p0=0.9695 corrupt_frac_t_0p8_1p0=0.2056 out_w_norm=11.6309 out_g_norm=0.4861 loss_all=1.5186 init_gold_top10=0.6299 init_gold_top100=0.7419 rollout_applied_pos_frac=0.4297 init_acc_rollout_applied=0.4838 init_acc_rollout_kept=0.5220 logit_acc_rollout_applied=0.6371 logit_acc_rollout_kept=0.6940
206
+ step=1000 epoch=1000/1000 epoch_step=1/1 micro_steps=1000 elapsed=21.2s lr=2.000000e-03 loss=1.4473 loss_recon=1.4473 loss_meanflow=0.0000 mean_model_t=0.5008 mean_corrupt_t=0.5008 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.4999 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.6901 corrupt_frac=1.0000 acc_corrupt=0.6901 loss_corrupt=1.4473 wrong_frac=0.4993 init_acc_corrupt=0.4866 acc_corrupt_t_0p0_0p2=0.0746 corrupt_frac_t_0p0_0p2=0.2034 acc_corrupt_t_0p2_0p4=0.5408 corrupt_frac_t_0p2_0p4=0.1967 acc_corrupt_t_0p4_0p6=0.8814 corrupt_frac_t_0p4_0p6=0.2010 acc_corrupt_t_0p6_0p8=0.9673 corrupt_frac_t_0p6_0p8=0.1933 acc_corrupt_t_0p8_1p0=0.9943 corrupt_frac_t_0p8_1p0=0.2055 out_w_norm=11.9531 out_g_norm=0.5594 loss_all=1.4882 init_gold_top10=0.6393 init_gold_top100=0.7602 rollout_applied_pos_frac=0.5547 init_acc_rollout_applied=0.5451 init_acc_rollout_kept=0.4294 logit_acc_rollout_applied=0.7345 logit_acc_rollout_kept=0.6450
207
+ NCCL version 2.25.1+cuda12.8
208
+ resumed_from=runs/train8_ctx1024_randk_p50_path3_rand2_3_unif0_0p25_ctx1024_uniformt_temp1_randk_20260518_010217/latest.pt start_step=1001
209
+ {
210
+ "device": "cuda:0",
211
+ "rank": 0,
212
+ "world_size": 4,
213
+ "samples": "owt_cached_chunks:8",
214
+ "vocab_size": 2664,
215
+ "tokenizer_vocab_size": 50257,
216
+ "save_dir": "runs/train8_ctx1024_randk_p50_path3_rand2_3_unif0_0p25_ctx1024_uniformt_temp1_randk_20260518_010217",
217
+ "batch_size": 128,
218
+ "grad_accum": 1,
219
+ "effective_batch_size": 512,
220
+ "global_batch_size": 512,
221
+ "lr_schedule": "constant_warmup",
222
+ "optimizer": "muon",
223
+ "epochs": 0.0,
224
+ "steps_per_epoch": 1,
225
+ "total_steps": 2000,
226
+ "warmup_steps": 10,
227
+ "warmup_epochs": -1.0,
228
+ "min_lr": 0.0,
229
+ "weight_decay": 0.1,
230
+ "output_weight_decay": -1.0,
231
+ "adamw_param_groups": "nanogpt",
232
+ "adam_beta1": 0.9,
233
+ "adam_beta2": 0.95,
234
+ "adam_eps": 1e-08,
235
+ "muon_impl": "legacy",
236
+ "muon_momentum": 0.95,
237
+ "muon_ns_steps": 5,
238
+ "muon_update_scale": 1.0,
239
+ "muon_nesterov": false,
240
+ "muon_width_scale": false,
241
+ "muon_grouping": "legacy_dim_ge_2",
242
+ "muon_param_count": 2616320,
243
+ "muon_adam_param_count": 8192,
244
+ "muon_param_names": [
245
+ "vocab_embed.embedding",
246
+ "sigma_map.net.0.weight",
247
+ "sigma_map.net.2.weight",
248
+ "blocks.0.attn_qkv.weight",
249
+ "blocks.0.attn_out.weight",
250
+ "blocks.0.mlp.0.weight",
251
+ "blocks.0.mlp.2.weight",
252
+ "blocks.0.adaLN_modulation.weight",
253
+ "blocks.1.attn_qkv.weight",
254
+ "blocks.1.attn_out.weight",
255
+ "blocks.1.mlp.0.weight",
256
+ "blocks.1.mlp.2.weight",
257
+ "blocks.1.adaLN_modulation.weight",
258
+ "blocks.2.attn_qkv.weight",
259
+ "blocks.2.attn_out.weight",
260
+ "blocks.2.mlp.0.weight",
261
+ "blocks.2.mlp.2.weight",
262
+ "blocks.2.adaLN_modulation.weight",
263
+ "output_layer.linear.weight",
264
+ "output_layer.adaLN_modulation.weight"
265
+ ],
266
+ "muon_adam_param_names": [
267
+ "sigma_map.net.0.bias",
268
+ "sigma_map.net.2.bias",
269
+ "blocks.0.norm1.weight",
270
+ "blocks.0.norm2.weight",
271
+ "blocks.0.mlp.0.bias",
272
+ "blocks.0.mlp.2.bias",
273
+ "blocks.0.adaLN_modulation.bias",
274
+ "blocks.1.norm1.weight",
275
+ "blocks.1.norm2.weight",
276
+ "blocks.1.mlp.0.bias",
277
+ "blocks.1.mlp.2.bias",
278
+ "blocks.1.adaLN_modulation.bias",
279
+ "blocks.2.norm1.weight",
280
+ "blocks.2.norm2.weight",
281
+ "blocks.2.mlp.0.bias",
282
+ "blocks.2.mlp.2.bias",
283
+ "blocks.2.adaLN_modulation.bias",
284
+ "output_layer.norm_final.weight",
285
+ "output_layer.adaLN_modulation.bias"
286
+ ],
287
+ "muon_effective_nesterov": false,
288
+ "muon_effective_width_scale": false,
289
+ "muon_effective_weight_decay": 0.1,
290
+ "muon_adam_fallback_nesterov": false,
291
+ "muon_adam_fallback_weight_decay": 0.1,
292
+ "ema_decay": 0.9999,
293
+ "ema_start_step": 0,
294
+ "model_type": "ddit",
295
+ "ddit_mlp_type": "gelu",
296
+ "elf_num_time_tokens": 4,
297
+ "elf_num_model_mode_tokens": 0,
298
+ "qk_norm": true,
299
+ "output_bias": false,
300
+ "output_init_std": -1.0,
301
+ "norm_type": "rmsnorm",
302
+ "target_loss": "hard_ce",
303
+ "linear_soft_target_power": 1.0,
304
+ "linear_soft_target_min_conf": 0.0,
305
+ "linear_soft_target_max_conf": 1.0,
306
+ "t_sampling_mode": "uniform",
307
+ "t_sampling_power": 1.0,
308
+ "t_sampling_eps": 0.0001,
309
+ "t_sampling_logit_mean": -1.5,
310
+ "t_sampling_logit_std": 0.8,
311
+ "dual_t": true,
312
+ "corrupt_t_mode": "same",
313
+ "corrupt_min_t": 0.0,
314
+ "corrupt_max_t": 1.0,
315
+ "prefix_block_prob": 0.0,
316
+ "prefix_block_len": 128,
317
+ "mask_ratio_floor_schedule": "none",
318
+ "dirichlet_endpoint_mode": "categorical_dual_t",
319
+ "dirichlet_semantic_t_mode": "same",
320
+ "dirichlet_semantic_t_value": 0.0,
321
+ "dirichlet_semantic_t_curve": "linear",
322
+ "dirichlet_semantic_t_power": 1.0,
323
+ "endpoint_sequence_random_prob_alpha": 0.0,
324
+ "categorical_wrong_from_full_vocab": true,
325
+ "categorical_wrong_from_batch_valid_tokens": false,
326
+ "categorical_wrong_basin_token_ids": "",
327
+ "categorical_wrong_basin_prob": 0.0,
328
+ "categorical_wrong_unigram_prob": 0.0,
329
+ "categorical_wrong_uniform_prob": 0.0,
330
+ "categorical_wrong_prob_floor": 0.0,
331
+ "categorical_wrong_corpus_unigram_path": "",
332
+ "categorical_wrong_corpus_unigram_alpha": 1.0,
333
+ "categorical_wrong_basin_shared_prob": 0.0,
334
+ "categorical_wrong_unigram_shared_prob": 0.0,
335
+ "mask_mixture_original_prob": 0.0,
336
+ "mask_mixture_lowk_prob": 0.0,
337
+ "mask_mixture_lowcorrupt_prob": 0.0,
338
+ "mask_mixture_block_prob": 0.0,
339
+ "mask_mixture_all_prob": 1.0,
340
+ "mask_mixture_lowk_clean_tokens": "0",
341
+ "mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
342
+ "mask_mixture_block_tokens": "64,128",
343
+ "simplex_bridge_sampler": "dirichlet",
344
+ "logistic_normal_sigma_min": 0.1,
345
+ "logistic_normal_sigma_max": 1.0,
346
+ "logistic_normal_tau_min": 1.0,
347
+ "logistic_normal_tau_max": 1.0,
348
+ "torch_compile": false,
349
+ "compile_mode": "max-autotune",
350
+ "state_format": "prob",
351
+ "meanflow_weight": 0.0,
352
+ "rollout_train_prob": 0.5,
353
+ "rollout_train_steps": 3,
354
+ "rollout_train_steps_min": 2,
355
+ "rollout_train_infer_steps": 1,
356
+ "rollout_train_time_mode": "sampled_path",
357
+ "rollout_train_s_dist": "uniform",
358
+ "rollout_train_s_min_frac": 0.0,
359
+ "rollout_train_s_max_frac": 0.25,
360
+ "rollout_train_s_beta_alpha": 2.0,
361
+ "rollout_train_s_beta_beta": 6.0,
362
+ "rollout_train_temp": 1.0,
363
+ "rollout_train_max_gamma": 1.0,
364
+ "rollout_train_corrupt_only": true,
365
+ "rollout_train_samplewise": true,
366
+ "rollout_train_compute_always": false,
367
+ "rollout_train_sync_t": true,
368
+ "bridge_noise_init": "logistic_normal",
369
+ "noise_sigma": -1.0,
370
+ "allow_tf32": true,
371
+ "activation_checkpointing": false,
372
+ "activation_checkpoint_interval": 1,
373
+ "activation_checkpoint_scope": "block",
374
+ "ddp_static_graph": false,
375
+ "ddp_gradient_as_bucket_view": true,
376
+ "blocking_data_transfer": false,
377
+ "dataloader_prefetch_factor": 4,
378
+ "full_train_stats": false,
379
+ "tokenized_hf": false,
380
+ "tokenized_pad_token": "pad",
381
+ "elf_conditional_hf": false,
382
+ "record_pad_truncate": false,
383
+ "record_add_eos": false,
384
+ "record_add_special_tokens": false,
385
+ "record_pad_token": "pad",
386
+ "record_shuffle_buffer": 10000,
387
+ "wrap": true,
388
+ "wrap_mode": "stream",
389
+ "wrap_record_buffer_size": 200,
390
+ "owt_cached_chunks": true,
391
+ "owt_chunk_cache_dir": "/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train8_compact_overfit",
392
+ "owt_chunk_cache_rebuild": false,
393
+ "owt_chunk_cache_write_batch": 4096,
394
+ "owt_exact_repeat_per_chunk": 64,
395
+ "online_chunk_shuffle": false,
396
+ "online_chunk_shuffle_buffer": 10000,
397
+ "openwebtext_split": "train_minus_100k",
398
+ "detokenizer": "auto",
399
+ "resolved_detokenizer": null,
400
+ "num_workers": 0,
401
+ "latest_every": 1000,
402
+ "resume_path": "runs/train8_ctx1024_randk_p50_path3_rand2_3_unif0_0p25_ctx1024_uniformt_temp1_randk_20260518_010217/latest.pt"
403
+ }
404
+ step=1100 epoch=1100/2000 epoch_step=1/1 micro_steps=1100 elapsed=21.9s lr=2.000000e-03 loss=1.1732 loss_recon=1.1732 loss_meanflow=0.0000 mean_model_t=0.4971 mean_corrupt_t=0.4971 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.5052 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7453 corrupt_frac=1.0000 acc_corrupt=0.7453 loss_corrupt=1.1732 wrong_frac=0.5028 init_acc_corrupt=0.4994 acc_corrupt_t_0p0_0p2=0.1029 corrupt_frac_t_0p0_0p2=0.2071 acc_corrupt_t_0p2_0p4=0.7067 corrupt_frac_t_0p2_0p4=0.1988 acc_corrupt_t_0p4_0p6=0.9560 corrupt_frac_t_0p4_0p6=0.1976 acc_corrupt_t_0p6_0p8=0.9915 corrupt_frac_t_0p6_0p8=0.1951 acc_corrupt_t_0p8_1p0=0.9987 corrupt_frac_t_0p8_1p0=0.2015 out_w_norm=12.1903 out_g_norm=0.5876 loss_all=0.9405 init_gold_top10=0.6469 init_gold_top100=0.7430 rollout_applied_pos_frac=0.4688 init_acc_rollout_applied=0.5473 init_acc_rollout_kept=0.4381 logit_acc_rollout_applied=0.7878 logit_acc_rollout_kept=0.7687
LTA_openwebtext_dualt/logs/softendpoint_mn_pilot_4gpu/train8_n1024_hard_ce_bridge_20260517_train8_overfit.log ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ NCCL version 2.25.1+cuda12.8
2
+ {
3
+ "device": "cuda:0",
4
+ "rank": 0,
5
+ "world_size": 4,
6
+ "samples": "owt_cached_chunks:8",
7
+ "vocab_size": 50257,
8
+ "tokenizer_vocab_size": 50257,
9
+ "save_dir": "runs/train8_n1024_hard_ce_bridge_20260517_train8_overfit",
10
+ "batch_size": 1,
11
+ "grad_accum": 1,
12
+ "effective_batch_size": 4,
13
+ "global_batch_size": 4,
14
+ "lr_schedule": "constant_warmup",
15
+ "optimizer": "muon",
16
+ "epochs": 0.0,
17
+ "steps_per_epoch": 2,
18
+ "total_steps": 1000,
19
+ "warmup_steps": 20,
20
+ "warmup_epochs": -1.0,
21
+ "min_lr": 0.0,
22
+ "weight_decay": 0.1,
23
+ "output_weight_decay": -1.0,
24
+ "adamw_param_groups": "nanogpt",
25
+ "adam_beta1": 0.9,
26
+ "adam_beta2": 0.95,
27
+ "adam_eps": 1e-08,
28
+ "muon_impl": "legacy",
29
+ "muon_momentum": 0.95,
30
+ "muon_ns_steps": 5,
31
+ "muon_update_scale": 1.0,
32
+ "muon_nesterov": false,
33
+ "muon_width_scale": false,
34
+ "muon_grouping": "legacy_dim_ge_2",
35
+ "muon_param_count": 169453056,
36
+ "muon_adam_param_count": 122368,
37
+ "muon_param_names": [
38
+ "vocab_embed.embedding",
39
+ "sigma_map.net.0.weight",
40
+ "sigma_map.net.2.weight",
41
+ "blocks.0.attn_qkv.weight",
42
+ "blocks.0.attn_out.weight",
43
+ "blocks.0.mlp.0.weight",
44
+ "blocks.0.mlp.2.weight",
45
+ "blocks.0.adaLN_modulation.weight",
46
+ "blocks.1.attn_qkv.weight",
47
+ "blocks.1.attn_out.weight",
48
+ "blocks.1.mlp.0.weight",
49
+ "blocks.1.mlp.2.weight",
50
+ "blocks.1.adaLN_modulation.weight",
51
+ "blocks.2.attn_qkv.weight",
52
+ "blocks.2.attn_out.weight",
53
+ "blocks.2.mlp.0.weight",
54
+ "blocks.2.mlp.2.weight",
55
+ "blocks.2.adaLN_modulation.weight",
56
+ "blocks.3.attn_qkv.weight",
57
+ "blocks.3.attn_out.weight",
58
+ "blocks.3.mlp.0.weight",
59
+ "blocks.3.mlp.2.weight",
60
+ "blocks.3.adaLN_modulation.weight",
61
+ "blocks.4.attn_qkv.weight",
62
+ "blocks.4.attn_out.weight",
63
+ "blocks.4.mlp.0.weight",
64
+ "blocks.4.mlp.2.weight",
65
+ "blocks.4.adaLN_modulation.weight",
66
+ "blocks.5.attn_qkv.weight",
67
+ "blocks.5.attn_out.weight",
68
+ "blocks.5.mlp.0.weight",
69
+ "blocks.5.mlp.2.weight",
70
+ "blocks.5.adaLN_modulation.weight",
71
+ "blocks.6.attn_qkv.weight",
72
+ "blocks.6.attn_out.weight",
73
+ "blocks.6.mlp.0.weight",
74
+ "blocks.6.mlp.2.weight",
75
+ "blocks.6.adaLN_modulation.weight",
76
+ "blocks.7.attn_qkv.weight",
77
+ "blocks.7.attn_out.weight",
78
+ "blocks.7.mlp.0.weight",
79
+ "blocks.7.mlp.2.weight",
80
+ "blocks.7.adaLN_modulation.weight",
81
+ "blocks.8.attn_qkv.weight",
82
+ "blocks.8.attn_out.weight",
83
+ "blocks.8.mlp.0.weight",
84
+ "blocks.8.mlp.2.weight",
85
+ "blocks.8.adaLN_modulation.weight",
86
+ "blocks.9.attn_qkv.weight",
87
+ "blocks.9.attn_out.weight",
88
+ "blocks.9.mlp.0.weight",
89
+ "blocks.9.mlp.2.weight",
90
+ "blocks.9.adaLN_modulation.weight",
91
+ "blocks.10.attn_qkv.weight",
92
+ "blocks.10.attn_out.weight",
93
+ "blocks.10.mlp.0.weight",
94
+ "blocks.10.mlp.2.weight",
95
+ "blocks.10.adaLN_modulation.weight",
96
+ "blocks.11.attn_qkv.weight",
97
+ "blocks.11.attn_out.weight",
98
+ "blocks.11.mlp.0.weight",
99
+ "blocks.11.mlp.2.weight",
100
+ "blocks.11.adaLN_modulation.weight",
101
+ "output_layer.linear.weight",
102
+ "output_layer.adaLN_modulation.weight"
103
+ ],
104
+ "muon_adam_param_names": [
105
+ "sigma_map.net.0.bias",
106
+ "sigma_map.net.2.bias",
107
+ "blocks.0.norm1.weight",
108
+ "blocks.0.norm2.weight",
109
+ "blocks.0.mlp.0.bias",
110
+ "blocks.0.mlp.2.bias",
111
+ "blocks.0.adaLN_modulation.bias",
112
+ "blocks.1.norm1.weight",
113
+ "blocks.1.norm2.weight",
114
+ "blocks.1.mlp.0.bias",
115
+ "blocks.1.mlp.2.bias",
116
+ "blocks.1.adaLN_modulation.bias",
117
+ "blocks.2.norm1.weight",
118
+ "blocks.2.norm2.weight",
119
+ "blocks.2.mlp.0.bias",
120
+ "blocks.2.mlp.2.bias",
121
+ "blocks.2.adaLN_modulation.bias",
122
+ "blocks.3.norm1.weight",
123
+ "blocks.3.norm2.weight",
124
+ "blocks.3.mlp.0.bias",
125
+ "blocks.3.mlp.2.bias",
126
+ "blocks.3.adaLN_modulation.bias",
127
+ "blocks.4.norm1.weight",
128
+ "blocks.4.norm2.weight",
129
+ "blocks.4.mlp.0.bias",
130
+ "blocks.4.mlp.2.bias",
131
+ "blocks.4.adaLN_modulation.bias",
132
+ "blocks.5.norm1.weight",
133
+ "blocks.5.norm2.weight",
134
+ "blocks.5.mlp.0.bias",
135
+ "blocks.5.mlp.2.bias",
136
+ "blocks.5.adaLN_modulation.bias",
137
+ "blocks.6.norm1.weight",
138
+ "blocks.6.norm2.weight",
139
+ "blocks.6.mlp.0.bias",
140
+ "blocks.6.mlp.2.bias",
141
+ "blocks.6.adaLN_modulation.bias",
142
+ "blocks.7.norm1.weight",
143
+ "blocks.7.norm2.weight",
144
+ "blocks.7.mlp.0.bias",
145
+ "blocks.7.mlp.2.bias",
146
+ "blocks.7.adaLN_modulation.bias",
147
+ "blocks.8.norm1.weight",
148
+ "blocks.8.norm2.weight",
149
+ "blocks.8.mlp.0.bias",
150
+ "blocks.8.mlp.2.bias",
151
+ "blocks.8.adaLN_modulation.bias",
152
+ "blocks.9.norm1.weight",
153
+ "blocks.9.norm2.weight",
154
+ "blocks.9.mlp.0.bias",
155
+ "blocks.9.mlp.2.bias",
156
+ "blocks.9.adaLN_modulation.bias",
157
+ "blocks.10.norm1.weight",
158
+ "blocks.10.norm2.weight",
159
+ "blocks.10.mlp.0.bias",
160
+ "blocks.10.mlp.2.bias",
161
+ "blocks.10.adaLN_modulation.bias",
162
+ "blocks.11.norm1.weight",
163
+ "blocks.11.norm2.weight",
164
+ "blocks.11.mlp.0.bias",
165
+ "blocks.11.mlp.2.bias",
166
+ "blocks.11.adaLN_modulation.bias",
167
+ "output_layer.norm_final.weight",
168
+ "output_layer.adaLN_modulation.bias"
169
+ ],
170
+ "muon_effective_nesterov": false,
171
+ "muon_effective_width_scale": false,
172
+ "muon_effective_weight_decay": 0.1,
173
+ "muon_adam_fallback_nesterov": false,
174
+ "muon_adam_fallback_weight_decay": 0.1,
175
+ "ema_decay": 0.9999,
176
+ "ema_start_step": 0,
177
+ "model_type": "ddit",
178
+ "elf_num_time_tokens": 4,
179
+ "elf_num_model_mode_tokens": 0,
180
+ "qk_norm": true,
181
+ "output_bias": false,
182
+ "output_init_std": -1.0,
183
+ "norm_type": "rmsnorm",
184
+ "target_loss": "hard_ce",
185
+ "linear_soft_target_power": 1.0,
186
+ "linear_soft_target_min_conf": 0.0,
187
+ "linear_soft_target_max_conf": 1.0,
188
+ "t_sampling_mode": "logit_normal",
189
+ "t_sampling_power": 1.0,
190
+ "t_sampling_eps": 0.0001,
191
+ "t_sampling_logit_mean": -1.5,
192
+ "t_sampling_logit_std": 0.8,
193
+ "dual_t": true,
194
+ "corrupt_t_mode": "same",
195
+ "corrupt_min_t": 0.0,
196
+ "corrupt_max_t": 1.0,
197
+ "prefix_block_prob": 0.0,
198
+ "prefix_block_len": 128,
199
+ "mask_ratio_floor_schedule": "none",
200
+ "dirichlet_endpoint_mode": "categorical_dual_t",
201
+ "dirichlet_semantic_t_mode": "same",
202
+ "dirichlet_semantic_t_value": 0.0,
203
+ "dirichlet_semantic_t_curve": "linear",
204
+ "dirichlet_semantic_t_power": 1.0,
205
+ "endpoint_sequence_random_prob_alpha": 0.0,
206
+ "categorical_wrong_from_full_vocab": true,
207
+ "categorical_wrong_from_batch_valid_tokens": false,
208
+ "categorical_wrong_basin_token_ids": "",
209
+ "categorical_wrong_basin_prob": 0.0,
210
+ "categorical_wrong_unigram_prob": 0.0,
211
+ "categorical_wrong_uniform_prob": 0.0,
212
+ "categorical_wrong_corpus_unigram_path": "",
213
+ "categorical_wrong_corpus_unigram_alpha": 1.0,
214
+ "categorical_wrong_basin_shared_prob": 0.0,
215
+ "categorical_wrong_unigram_shared_prob": 0.0,
216
+ "mask_mixture_original_prob": 0.0,
217
+ "mask_mixture_lowk_prob": 1.0,
218
+ "mask_mixture_lowcorrupt_prob": 0.0,
219
+ "mask_mixture_block_prob": 0.0,
220
+ "mask_mixture_all_prob": 0.0,
221
+ "mask_mixture_lowk_clean_tokens": "64,128,256",
222
+ "mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
223
+ "mask_mixture_block_tokens": "64,128",
224
+ "simplex_bridge_sampler": "dirichlet",
225
+ "logistic_normal_sigma_min": 0.18,
226
+ "logistic_normal_sigma_max": 2.2,
227
+ "logistic_normal_tau_min": 0.65,
228
+ "logistic_normal_tau_max": 1.15,
229
+ "torch_compile": false,
230
+ "compile_mode": "max-autotune",
231
+ "state_format": "prob",
232
+ "meanflow_weight": 0.0,
233
+ "rollout_train_prob": 0.0,
234
+ "rollout_train_steps": 1,
235
+ "rollout_train_infer_steps": 64,
236
+ "rollout_train_temp": 1.45,
237
+ "rollout_train_max_gamma": 1.0,
238
+ "rollout_train_corrupt_only": true,
239
+ "rollout_train_samplewise": false,
240
+ "rollout_train_compute_always": false,
241
+ "bridge_noise_init": "logistic_normal",
242
+ "noise_sigma": -1.0,
243
+ "allow_tf32": true,
244
+ "activation_checkpointing": false,
245
+ "activation_checkpoint_interval": 1,
246
+ "activation_checkpoint_scope": "block",
247
+ "ddp_static_graph": false,
248
+ "ddp_gradient_as_bucket_view": true,
249
+ "blocking_data_transfer": false,
250
+ "dataloader_prefetch_factor": 4,
251
+ "full_train_stats": false,
252
+ "tokenized_hf": false,
253
+ "tokenized_pad_token": "pad",
254
+ "elf_conditional_hf": false,
255
+ "record_pad_truncate": false,
256
+ "record_add_eos": false,
257
+ "record_add_special_tokens": false,
258
+ "record_pad_token": "pad",
259
+ "record_shuffle_buffer": 10000,
260
+ "wrap": true,
261
+ "wrap_mode": "stream",
262
+ "wrap_record_buffer_size": 200,
263
+ "owt_cached_chunks": true,
264
+ "owt_chunk_cache_dir": "/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train8_overfit",
265
+ "owt_chunk_cache_rebuild": false,
266
+ "owt_chunk_cache_write_batch": 4096,
267
+ "owt_exact_repeat_per_chunk": 0,
268
+ "online_chunk_shuffle": false,
269
+ "online_chunk_shuffle_buffer": 10000,
270
+ "openwebtext_split": "train_minus_100k",
271
+ "detokenizer": "auto",
272
+ "resolved_detokenizer": null,
273
+ "num_workers": 0,
274
+ "latest_every": 50,
275
+ "resume_path": ""
276
+ }
277
+ step=25 epoch=13/500 epoch_step=1/2 micro_steps=25 elapsed=3.9s lr=2.000000e-03 loss=10.7864 loss_recon=10.7864 loss_meanflow=0.0000 mean_model_t=0.2229 mean_corrupt_t=0.2229 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1332 corrupt_frac=0.8125 acc_corrupt=0.1000 loss_corrupt=10.7864 wrong_frac=0.7737 init_acc_corrupt=0.1321 acc_corrupt_t_0p0_0p2=0.0574 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=0.2483 out_g_norm=1.5841 acc_corrupt_t_0p2_0p4=0.1211 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.2533 corrupt_frac_t_0p4_0p6=1.0000 loss_all=10.6718 init_gold_top10=0.3646 init_gold_top100=0.3646
278
+ step=50 epoch=25/500 epoch_step=2/2 micro_steps=50 elapsed=3.2s lr=2.000000e-03 loss=10.6536 loss_recon=10.6536 loss_meanflow=0.0000 mean_model_t=0.2418 mean_corrupt_t=0.2418 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1506 corrupt_frac=0.8675 acc_corrupt=0.1229 loss_corrupt=10.6536 wrong_frac=0.7587 init_acc_corrupt=0.1501 acc_corrupt_t_0p0_0p2=0.0635 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=1.1359 out_g_norm=2.1812 acc_corrupt_t_0p2_0p4=0.1465 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.2672 corrupt_frac_t_0p4_0p6=1.0000 loss_all=10.6088 init_gold_top10=0.1135 init_gold_top100=0.2458
279
+ step=75 epoch=38/500 epoch_step=1/2 micro_steps=75 elapsed=7.1s lr=2.000000e-03 loss=10.4160 loss_recon=10.4160 loss_meanflow=0.0000 mean_model_t=0.1995 mean_corrupt_t=0.1995 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1303 corrupt_frac=0.8675 acc_corrupt=0.1041 loss_corrupt=10.4160 wrong_frac=0.7995 init_acc_corrupt=0.1001 acc_corrupt_t_0p2_0p4=0.1574 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=1.7525 out_g_norm=2.6698 acc_corrupt_t_0p0_0p2=0.0612 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.2458 corrupt_frac_t_0p4_0p6=1.0000 loss_all=9.9453 init_gold_top10=0.3073 init_gold_top100=0.3292
280
+ step=100 epoch=50/500 epoch_step=2/2 micro_steps=100 elapsed=3.2s lr=2.000000e-03 loss=10.0502 loss_recon=10.0502 loss_meanflow=0.0000 mean_model_t=0.2487 mean_corrupt_t=0.2487 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1689 corrupt_frac=0.8425 acc_corrupt=0.1328 loss_corrupt=10.0502 wrong_frac=0.7557 init_acc_corrupt=0.1566 acc_corrupt_t_0p0_0p2=0.0632 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=2.2511 out_g_norm=2.9631 acc_corrupt_t_0p2_0p4=0.1522 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.2891 corrupt_frac_t_0p4_0p6=1.0000 loss_all=9.5690 init_gold_top10=0.3372 init_gold_top100=0.3411
281
+ step=125 epoch=63/500 epoch_step=1/2 micro_steps=125 elapsed=7.3s lr=2.000000e-03 loss=9.8011 loss_recon=9.8011 loss_meanflow=0.0000 mean_model_t=0.2121 mean_corrupt_t=0.2121 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1487 corrupt_frac=0.8450 acc_corrupt=0.1153 loss_corrupt=9.8011 wrong_frac=0.7876 init_acc_corrupt=0.1138 acc_corrupt_t_0p0_0p2=0.0628 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=2.7598 out_g_norm=3.4495 acc_corrupt_t_0p2_0p4=0.1326 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.3001 corrupt_frac_t_0p4_0p6=1.0000 loss_all=9.5979 init_gold_top10=0.1384 init_gold_top100=0.2511
282
+ step=150 epoch=75/500 epoch_step=2/2 micro_steps=150 elapsed=3.2s lr=2.000000e-03 loss=9.3065 loss_recon=9.3065 loss_meanflow=0.0000 mean_model_t=0.2285 mean_corrupt_t=0.2285 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1630 corrupt_frac=0.8525 acc_corrupt=0.1323 loss_corrupt=9.3065 wrong_frac=0.7725 init_acc_corrupt=0.1303 acc_corrupt_t_0p2_0p4=0.1598 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=3.2724 out_g_norm=3.7085 acc_corrupt_t_0p4_0p6=0.3238 corrupt_frac_t_0p4_0p6=1.0000 acc_corrupt_t_0p0_0p2=0.0682 corrupt_frac_t_0p0_0p2=1.0000 loss_all=8.2090 init_gold_top10=0.2612 init_gold_top100=0.2991
283
+ step=175 epoch=88/500 epoch_step=1/2 micro_steps=175 elapsed=7.2s lr=2.000000e-03 loss=9.0808 loss_recon=9.0808 loss_meanflow=0.0000 mean_model_t=0.2193 mean_corrupt_t=0.2193 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1660 corrupt_frac=0.8375 acc_corrupt=0.1244 loss_corrupt=9.0808 wrong_frac=0.7870 init_acc_corrupt=0.1225 acc_corrupt_t_0p4_0p6=0.3179 corrupt_frac_t_0p4_0p6=1.0000 out_w_norm=3.7787 out_g_norm=3.8401 acc_corrupt_t_0p2_0p4=0.1766 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p0_0p2=0.0676 corrupt_frac_t_0p0_0p2=1.0000 loss_all=7.1403 init_gold_top10=0.3661 init_gold_top100=0.3694
284
+ step=200 epoch=100/500 epoch_step=2/2 micro_steps=200 elapsed=3.2s lr=2.000000e-03 loss=8.9196 loss_recon=8.9196 loss_meanflow=0.0000 mean_model_t=0.2000 mean_corrupt_t=0.2000 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1503 corrupt_frac=0.8400 acc_corrupt=0.1162 loss_corrupt=8.9196 wrong_frac=0.7967 init_acc_corrupt=0.1161 acc_corrupt_t_0p0_0p2=0.0644 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=4.2129 out_g_norm=3.8417 acc_corrupt_t_0p2_0p4=0.1556 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.2713 corrupt_frac_t_0p4_0p6=1.0000 loss_all=9.1738 init_gold_top10=0.0104 init_gold_top100=0.1367
285
+ step=225 epoch=113/500 epoch_step=1/2 micro_steps=225 elapsed=6.5s lr=2.000000e-03 loss=8.5644 loss_recon=8.5644 loss_meanflow=0.0000 mean_model_t=0.2059 mean_corrupt_t=0.2059 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1668 corrupt_frac=0.8475 acc_corrupt=0.1290 loss_corrupt=8.5644 wrong_frac=0.7966 init_acc_corrupt=0.1126 acc_corrupt_t_0p0_0p2=0.0478 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=4.5749 out_g_norm=3.5982 acc_corrupt_t_0p2_0p4=0.1699 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.3385 corrupt_frac_t_0p4_0p6=1.0000 loss_all=7.4746 init_gold_top10=0.2122 init_gold_top100=0.3138
286
+ step=250 epoch=125/500 epoch_step=2/2 micro_steps=250 elapsed=3.2s lr=2.000000e-03 loss=8.4566 loss_recon=8.4566 loss_meanflow=0.0000 mean_model_t=0.1793 mean_corrupt_t=0.1793 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1480 corrupt_frac=0.8500 acc_corrupt=0.1107 loss_corrupt=8.4566 wrong_frac=0.8201 init_acc_corrupt=0.0854 acc_corrupt_t_0p2_0p4=0.1787 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=4.8587 out_g_norm=3.4425 acc_corrupt_t_0p0_0p2=0.0583 corrupt_frac_t_0p0_0p2=1.0000 loss_all=8.5487 init_gold_top10=0.0837 init_gold_top100=0.1987
287
+ step=275 epoch=138/500 epoch_step=1/2 micro_steps=275 elapsed=7.4s lr=2.000000e-03 loss=8.2703 loss_recon=8.2703 loss_meanflow=0.0000 mean_model_t=0.2038 mean_corrupt_t=0.2038 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1492 corrupt_frac=0.8600 acc_corrupt=0.1137 loss_corrupt=8.2703 wrong_frac=0.8049 init_acc_corrupt=0.0914 acc_corrupt_t_0p0_0p2=0.0710 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=5.0798 out_g_norm=3.1796 acc_corrupt_t_0p2_0p4=0.1635 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p6_0p8=0.3464 corrupt_frac_t_0p6_0p8=1.0000 loss_all=8.2362 init_gold_top10=0.0182 init_gold_top100=0.1224
288
+ step=300 epoch=150/500 epoch_step=2/2 micro_steps=300 elapsed=3.2s lr=2.000000e-03 loss=8.1003 loss_recon=8.1003 loss_meanflow=0.0000 mean_model_t=0.1754 mean_corrupt_t=0.1754 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1366 corrupt_frac=0.8725 acc_corrupt=0.1104 loss_corrupt=8.1003 wrong_frac=0.8226 init_acc_corrupt=0.0753 acc_corrupt_t_0p2_0p4=0.1806 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=5.2769 out_g_norm=2.9771 acc_corrupt_t_0p0_0p2=0.0692 corrupt_frac_t_0p0_0p2=1.0000 loss_all=6.8203 init_gold_top10=0.2604 init_gold_top100=0.3187
289
+ step=325 epoch=163/500 epoch_step=1/2 micro_steps=325 elapsed=7.2s lr=2.000000e-03 loss=7.5749 loss_recon=7.5749 loss_meanflow=0.0000 mean_model_t=0.2307 mean_corrupt_t=0.2307 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1913 corrupt_frac=0.8425 acc_corrupt=0.1466 loss_corrupt=7.5749 wrong_frac=0.7709 init_acc_corrupt=0.1407 acc_corrupt_t_0p2_0p4=0.1861 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=5.4730 out_g_norm=2.6582 acc_corrupt_t_0p0_0p2=0.0727 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3095 corrupt_frac_t_0p4_0p6=1.0000 loss_all=7.2629 init_gold_top10=0.1471 init_gold_top100=0.2539
290
+ step=350 epoch=175/500 epoch_step=2/2 micro_steps=350 elapsed=3.2s lr=2.000000e-03 loss=7.7026 loss_recon=7.7026 loss_meanflow=0.0000 mean_model_t=0.2000 mean_corrupt_t=0.2000 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1591 corrupt_frac=0.8525 acc_corrupt=0.1239 loss_corrupt=7.7026 wrong_frac=0.7984 init_acc_corrupt=0.1058 acc_corrupt_t_0p2_0p4=0.1731 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=5.6895 out_g_norm=2.5375 acc_corrupt_t_0p0_0p2=0.0568 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3149 corrupt_frac_t_0p4_0p6=1.0000 loss_all=7.8661 init_gold_top10=0.0716 init_gold_top100=0.1667
291
+ step=375 epoch=188/500 epoch_step=1/2 micro_steps=375 elapsed=6.0s lr=2.000000e-03 loss=7.5492 loss_recon=7.5492 loss_meanflow=0.0000 mean_model_t=0.1848 mean_corrupt_t=0.1848 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1554 corrupt_frac=0.8275 acc_corrupt=0.1158 loss_corrupt=7.5492 wrong_frac=0.8215 init_acc_corrupt=0.0898 acc_corrupt_t_0p0_0p2=0.0650 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=5.8912 out_g_norm=2.7020 acc_corrupt_t_0p2_0p4=0.1840 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.3259 corrupt_frac_t_0p4_0p6=1.0000 loss_all=6.7869 init_gold_top10=0.2943 init_gold_top100=0.3060
292
+ step=400 epoch=200/500 epoch_step=2/2 micro_steps=400 elapsed=3.2s lr=2.000000e-03 loss=6.8606 loss_recon=6.8606 loss_meanflow=0.0000 mean_model_t=0.2432 mean_corrupt_t=0.2432 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2130 corrupt_frac=0.8475 acc_corrupt=0.1721 loss_corrupt=6.8606 wrong_frac=0.7520 init_acc_corrupt=0.1677 acc_corrupt_t_0p4_0p6=0.4051 corrupt_frac_t_0p4_0p6=1.0000 out_w_norm=6.0772 out_g_norm=2.5310 acc_corrupt_t_0p0_0p2=0.0706 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p2_0p4=0.2320 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p6_0p8=0.4927 corrupt_frac_t_0p6_0p8=1.0000 loss_all=7.3142 init_gold_top10=0.1354 init_gold_top100=0.2435
293
+ step=425 epoch=213/500 epoch_step=1/2 micro_steps=425 elapsed=6.6s lr=2.000000e-03 loss=6.8943 loss_recon=6.8943 loss_meanflow=0.0000 mean_model_t=0.2254 mean_corrupt_t=0.2254 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2041 corrupt_frac=0.8475 acc_corrupt=0.1616 loss_corrupt=6.8943 wrong_frac=0.7710 init_acc_corrupt=0.1320 acc_corrupt_t_0p0_0p2=0.0746 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=6.2709 out_g_norm=2.3976 acc_corrupt_t_0p2_0p4=0.2291 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.4657 corrupt_frac_t_0p4_0p6=1.0000 loss_all=8.1770 init_gold_top10=0.1339 init_gold_top100=0.2333
294
+ step=450 epoch=225/500 epoch_step=2/2 micro_steps=450 elapsed=3.2s lr=2.000000e-03 loss=6.5882 loss_recon=6.5882 loss_meanflow=0.0000 mean_model_t=0.2573 mean_corrupt_t=0.2573 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2300 corrupt_frac=0.8425 acc_corrupt=0.1835 loss_corrupt=6.5882 wrong_frac=0.7421 init_acc_corrupt=0.1684 acc_corrupt_t_0p2_0p4=0.1989 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=6.4616 out_g_norm=2.4658 acc_corrupt_t_0p0_0p2=0.0797 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.4002 corrupt_frac_t_0p4_0p6=1.0000 acc_corrupt_t_0p6_0p8=0.4933 corrupt_frac_t_0p6_0p8=1.0000 loss_all=6.0882 init_gold_top10=0.2148 init_gold_top100=0.3008
295
+ step=475 epoch=238/500 epoch_step=1/2 micro_steps=475 elapsed=7.2s lr=2.000000e-03 loss=6.6024 loss_recon=6.6024 loss_meanflow=0.0000 mean_model_t=0.2275 mean_corrupt_t=0.2275 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2090 corrupt_frac=0.8325 acc_corrupt=0.1627 loss_corrupt=6.6024 wrong_frac=0.7694 init_acc_corrupt=0.1501 acc_corrupt_t_0p0_0p2=0.0717 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=6.6827 out_g_norm=2.4195 acc_corrupt_t_0p2_0p4=0.2184 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.3401 corrupt_frac_t_0p4_0p6=1.0000 loss_all=7.1854 init_gold_top10=0.0982 init_gold_top100=0.2254
296
+ step=500 epoch=250/500 epoch_step=2/2 micro_steps=500 elapsed=3.2s lr=2.000000e-03 loss=6.6715 loss_recon=6.6715 loss_meanflow=0.0000 mean_model_t=0.2087 mean_corrupt_t=0.2087 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1925 corrupt_frac=0.8325 acc_corrupt=0.1487 loss_corrupt=6.6715 wrong_frac=0.7906 init_acc_corrupt=0.1262 acc_corrupt_t_0p2_0p4=0.2258 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=6.8888 out_g_norm=2.5030 acc_corrupt_t_0p0_0p2=0.0652 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.4675 corrupt_frac_t_0p4_0p6=1.0000 loss_all=5.3250 init_gold_top10=0.3604 init_gold_top100=0.3615
297
+ step=525 epoch=263/500 epoch_step=1/2 micro_steps=525 elapsed=6.0s lr=2.000000e-03 loss=6.4744 loss_recon=6.4744 loss_meanflow=0.0000 mean_model_t=0.2178 mean_corrupt_t=0.2178 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1979 corrupt_frac=0.8500 acc_corrupt=0.1515 loss_corrupt=6.4744 wrong_frac=0.7787 init_acc_corrupt=0.1194 acc_corrupt_t_0p2_0p4=0.1977 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=7.0787 out_g_norm=2.5506 acc_corrupt_t_0p0_0p2=0.0755 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3393 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.8922 init_gold_top10=0.4342 init_gold_top100=0.4353
298
+ step=550 epoch=275/500 epoch_step=2/2 micro_steps=550 elapsed=3.2s lr=2.000000e-03 loss=6.2602 loss_recon=6.2602 loss_meanflow=0.0000 mean_model_t=0.2122 mean_corrupt_t=0.2122 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2027 corrupt_frac=0.8475 acc_corrupt=0.1587 loss_corrupt=6.2602 wrong_frac=0.7820 init_acc_corrupt=0.1212 acc_corrupt_t_0p0_0p2=0.0710 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=7.2750 out_g_norm=2.6923 acc_corrupt_t_0p2_0p4=0.2271 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.3438 corrupt_frac_t_0p4_0p6=1.0000 loss_all=7.6317 init_gold_top10=0.0145 init_gold_top100=0.1194
299
+ step=575 epoch=288/500 epoch_step=1/2 micro_steps=575 elapsed=6.0s lr=2.000000e-03 loss=6.6165 loss_recon=6.6165 loss_meanflow=0.0000 mean_model_t=0.1913 mean_corrupt_t=0.1913 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1636 corrupt_frac=0.8675 acc_corrupt=0.1295 loss_corrupt=6.6165 wrong_frac=0.8059 init_acc_corrupt=0.0965 acc_corrupt_t_0p0_0p2=0.0527 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=7.4940 out_g_norm=2.8035 acc_corrupt_t_0p2_0p4=0.1950 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p6_0p8=0.6003 corrupt_frac_t_0p6_0p8=1.0000 loss_all=8.5299 init_gold_top10=0.0146 init_gold_top100=0.1229
300
+ step=600 epoch=300/500 epoch_step=2/2 micro_steps=600 elapsed=3.2s lr=2.000000e-03 loss=6.3523 loss_recon=6.3523 loss_meanflow=0.0000 mean_model_t=0.2002 mean_corrupt_t=0.2002 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1777 corrupt_frac=0.8600 acc_corrupt=0.1367 loss_corrupt=6.3523 wrong_frac=0.8077 init_acc_corrupt=0.1017 acc_corrupt_t_0p2_0p4=0.1968 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=7.6925 out_g_norm=2.6466 acc_corrupt_t_0p0_0p2=0.0663 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3737 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.8443 init_gold_top10=0.2422 init_gold_top100=0.3177
301
+ step=625 epoch=313/500 epoch_step=1/2 micro_steps=625 elapsed=7.0s lr=2.000000e-03 loss=6.1065 loss_recon=6.1065 loss_meanflow=0.0000 mean_model_t=0.1977 mean_corrupt_t=0.1977 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1844 corrupt_frac=0.8650 acc_corrupt=0.1472 loss_corrupt=6.1065 wrong_frac=0.8011 init_acc_corrupt=0.1018 acc_corrupt_t_0p2_0p4=0.2090 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=7.8705 out_g_norm=2.7706 acc_corrupt_t_0p0_0p2=0.0697 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.4319 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.5681 init_gold_top10=0.3052 init_gold_top100=0.3052
302
+ step=650 epoch=325/500 epoch_step=2/2 micro_steps=650 elapsed=3.2s lr=2.000000e-03 loss=5.3613 loss_recon=5.3613 loss_meanflow=0.0000 mean_model_t=0.2686 mean_corrupt_t=0.2686 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2696 corrupt_frac=0.8500 acc_corrupt=0.2165 loss_corrupt=5.3613 wrong_frac=0.7288 init_acc_corrupt=0.1849 acc_corrupt_t_0p2_0p4=0.2419 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=8.0587 out_g_norm=2.9153 acc_corrupt_t_0p0_0p2=0.0758 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3771 corrupt_frac_t_0p4_0p6=1.0000 acc_corrupt_t_0p6_0p8=0.5826 corrupt_frac_t_0p6_0p8=1.0000 loss_all=6.6134 init_gold_top10=0.0586 init_gold_top100=0.2070
303
+ step=675 epoch=338/500 epoch_step=1/2 micro_steps=675 elapsed=6.0s lr=2.000000e-03 loss=5.6003 loss_recon=5.6003 loss_meanflow=0.0000 mean_model_t=0.2244 mean_corrupt_t=0.2244 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2217 corrupt_frac=0.8450 acc_corrupt=0.1732 loss_corrupt=5.6003 wrong_frac=0.7735 init_acc_corrupt=0.1386 acc_corrupt_t_0p2_0p4=0.2217 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=8.2271 out_g_norm=3.2682 acc_corrupt_t_0p0_0p2=0.0735 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3936 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.4996 init_gold_top10=0.2533 init_gold_top100=0.2935
304
+ step=700 epoch=350/500 epoch_step=2/2 micro_steps=700 elapsed=3.2s lr=2.000000e-03 loss=5.7057 loss_recon=5.7057 loss_meanflow=0.0000 mean_model_t=0.1820 mean_corrupt_t=0.1820 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1834 corrupt_frac=0.8550 acc_corrupt=0.1393 loss_corrupt=5.7057 wrong_frac=0.8183 init_acc_corrupt=0.0776 acc_corrupt_t_0p2_0p4=0.1946 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=8.4344 out_g_norm=3.0268 acc_corrupt_t_0p0_0p2=0.0733 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.3880 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.3138 init_gold_top10=0.3094 init_gold_top100=0.3146
305
+ step=725 epoch=363/500 epoch_step=1/2 micro_steps=725 elapsed=6.0s lr=2.000000e-03 loss=5.3601 loss_recon=5.3601 loss_meanflow=0.0000 mean_model_t=0.2167 mean_corrupt_t=0.2167 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2263 corrupt_frac=0.8275 acc_corrupt=0.1640 loss_corrupt=5.3601 wrong_frac=0.7835 init_acc_corrupt=0.1159 acc_corrupt_t_0p0_0p2=0.0829 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=8.6344 out_g_norm=3.1377 acc_corrupt_t_0p2_0p4=0.2374 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.4219 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.7756 init_gold_top10=0.1979 init_gold_top100=0.2875
306
+ step=750 epoch=375/500 epoch_step=2/2 micro_steps=750 elapsed=3.2s lr=2.000000e-03 loss=5.4931 loss_recon=5.4931 loss_meanflow=0.0000 mean_model_t=0.2249 mean_corrupt_t=0.2249 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2207 corrupt_frac=0.8575 acc_corrupt=0.1721 loss_corrupt=5.4931 wrong_frac=0.7773 init_acc_corrupt=0.1290 acc_corrupt_t_0p2_0p4=0.2201 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=8.8160 out_g_norm=3.4480 acc_corrupt_t_0p4_0p6=0.4258 corrupt_frac_t_0p4_0p6=1.0000 acc_corrupt_t_0p8_1p0=0.7790 corrupt_frac_t_0p8_1p0=1.0000 acc_corrupt_t_0p0_0p2=0.0803 corrupt_frac_t_0p0_0p2=1.0000 loss_all=4.8351 init_gold_top10=0.1654 init_gold_top100=0.2630
307
+ step=775 epoch=388/500 epoch_step=1/2 micro_steps=775 elapsed=6.8s lr=2.000000e-03 loss=5.2346 loss_recon=5.2346 loss_meanflow=0.0000 mean_model_t=0.2047 mean_corrupt_t=0.2047 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2190 corrupt_frac=0.8525 acc_corrupt=0.1645 loss_corrupt=5.2346 wrong_frac=0.8014 init_acc_corrupt=0.1122 acc_corrupt_t_0p0_0p2=0.0713 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=8.9976 out_g_norm=3.1033 acc_corrupt_t_0p2_0p4=0.2526 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.4590 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.2737 init_gold_top10=0.2271 init_gold_top100=0.3146
308
+ step=800 epoch=400/500 epoch_step=2/2 micro_steps=800 elapsed=3.2s lr=2.000000e-03 loss=4.9681 loss_recon=4.9681 loss_meanflow=0.0000 mean_model_t=0.2219 mean_corrupt_t=0.2219 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2404 corrupt_frac=0.8575 acc_corrupt=0.1852 loss_corrupt=4.9681 wrong_frac=0.7766 init_acc_corrupt=0.1295 acc_corrupt_t_0p4_0p6=0.4122 corrupt_frac_t_0p4_0p6=1.0000 out_w_norm=9.1570 out_g_norm=3.3703 acc_corrupt_t_0p0_0p2=0.0803 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p2_0p4=0.2411 corrupt_frac_t_0p2_0p4=1.0000 loss_all=5.4143 init_gold_top10=0.1094 init_gold_top100=0.2176
309
+ step=825 epoch=413/500 epoch_step=1/2 micro_steps=825 elapsed=6.0s lr=2.000000e-03 loss=5.3848 loss_recon=5.3848 loss_meanflow=0.0000 mean_model_t=0.1673 mean_corrupt_t=0.1673 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1622 corrupt_frac=0.8700 acc_corrupt=0.1154 loss_corrupt=5.3848 wrong_frac=0.8311 init_acc_corrupt=0.0565 acc_corrupt_t_0p0_0p2=0.0731 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=9.3244 out_g_norm=3.7546 acc_corrupt_t_0p2_0p4=0.1967 corrupt_frac_t_0p2_0p4=1.0000 loss_all=7.0083 init_gold_top10=0.0188 init_gold_top100=0.1240
310
+ step=850 epoch=425/500 epoch_step=2/2 micro_steps=850 elapsed=3.2s lr=2.000000e-03 loss=4.8082 loss_recon=4.8082 loss_meanflow=0.0000 mean_model_t=0.2215 mean_corrupt_t=0.2215 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2529 corrupt_frac=0.8475 acc_corrupt=0.1910 loss_corrupt=4.8082 wrong_frac=0.7834 init_acc_corrupt=0.1358 acc_corrupt_t_0p0_0p2=0.0694 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=9.4903 out_g_norm=3.4038 acc_corrupt_t_0p2_0p4=0.2639 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.5268 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.8356 init_gold_top10=0.0982 init_gold_top100=0.2054
311
+ step=875 epoch=438/500 epoch_step=1/2 micro_steps=875 elapsed=6.0s lr=2.000000e-03 loss=4.2792 loss_recon=4.2792 loss_meanflow=0.0000 mean_model_t=0.2682 mean_corrupt_t=0.2682 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.3117 corrupt_frac=0.8500 acc_corrupt=0.2463 loss_corrupt=4.2792 wrong_frac=0.7365 init_acc_corrupt=0.1897 acc_corrupt_t_0p4_0p6=0.4762 corrupt_frac_t_0p4_0p6=1.0000 out_w_norm=9.6246 out_g_norm=3.7357 acc_corrupt_t_0p0_0p2=0.0808 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p2_0p4=0.2545 corrupt_frac_t_0p2_0p4=1.0000 loss_all=2.5656 init_gold_top10=0.4475 init_gold_top100=0.4487
312
+ step=900 epoch=450/500 epoch_step=2/2 micro_steps=900 elapsed=3.2s lr=2.000000e-03 loss=5.0254 loss_recon=5.0254 loss_meanflow=0.0000 mean_model_t=0.1950 mean_corrupt_t=0.1950 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2067 corrupt_frac=0.8675 acc_corrupt=0.1626 loss_corrupt=5.0254 wrong_frac=0.8093 init_acc_corrupt=0.1086 acc_corrupt_t_0p0_0p2=0.0668 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=9.7542 out_g_norm=4.1217 acc_corrupt_t_0p2_0p4=0.2661 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p6_0p8=0.5938 corrupt_frac_t_0p6_0p8=1.0000 acc_corrupt_t_0p4_0p6=0.5379 corrupt_frac_t_0p4_0p6=1.0000 loss_all=6.6078 init_gold_top10=0.1440 init_gold_top100=0.2411
313
+ step=925 epoch=463/500 epoch_step=1/2 micro_steps=925 elapsed=6.7s lr=2.000000e-03 loss=4.7291 loss_recon=4.7291 loss_meanflow=0.0000 mean_model_t=0.2037 mean_corrupt_t=0.2037 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2236 corrupt_frac=0.8550 acc_corrupt=0.1681 loss_corrupt=4.7291 wrong_frac=0.7979 init_acc_corrupt=0.1045 acc_corrupt_t_0p2_0p4=0.2907 corrupt_frac_t_0p2_0p4=1.0000 out_w_norm=9.8717 out_g_norm=4.3022 acc_corrupt_t_0p0_0p2=0.0891 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p4_0p6=0.4844 corrupt_frac_t_0p4_0p6=1.0000 loss_all=6.5996 init_gold_top10=0.0480 init_gold_top100=0.1797
314
+ step=950 epoch=475/500 epoch_step=2/2 micro_steps=950 elapsed=3.2s lr=2.000000e-03 loss=4.8115 loss_recon=4.8115 loss_meanflow=0.0000 mean_model_t=0.1864 mean_corrupt_t=0.1864 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2011 corrupt_frac=0.8275 acc_corrupt=0.1441 loss_corrupt=4.8115 wrong_frac=0.8129 init_acc_corrupt=0.0838 acc_corrupt_t_0p0_0p2=0.0880 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=9.9749 out_g_norm=4.6037 acc_corrupt_t_0p2_0p4=0.2572 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.4464 corrupt_frac_t_0p4_0p6=1.0000 loss_all=4.9126 init_gold_top10=0.0368 init_gold_top100=0.1462
315
+ step=975 epoch=488/500 epoch_step=1/2 micro_steps=975 elapsed=6.0s lr=2.000000e-03 loss=4.6119 loss_recon=4.6119 loss_meanflow=0.0000 mean_model_t=0.2108 mean_corrupt_t=0.2108 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2337 corrupt_frac=0.8600 acc_corrupt=0.1779 loss_corrupt=4.6119 wrong_frac=0.7927 init_acc_corrupt=0.1185 acc_corrupt_t_0p6_0p8=0.7279 corrupt_frac_t_0p6_0p8=1.0000 out_w_norm=10.0757 out_g_norm=4.2819 acc_corrupt_t_0p0_0p2=0.0761 corrupt_frac_t_0p0_0p2=1.0000 acc_corrupt_t_0p2_0p4=0.2466 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.5458 corrupt_frac_t_0p4_0p6=1.0000 loss_all=6.7924 init_gold_top10=0.0292 init_gold_top100=0.1271
316
+ step=1000 epoch=500/500 epoch_step=2/2 micro_steps=1000 elapsed=3.2s lr=2.000000e-03 loss=4.3429 loss_recon=4.3429 loss_meanflow=0.0000 mean_model_t=0.2305 mean_corrupt_t=0.2305 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2578 corrupt_frac=0.8625 acc_corrupt=0.2071 loss_corrupt=4.3429 wrong_frac=0.7682 init_acc_corrupt=0.1369 acc_corrupt_t_0p0_0p2=0.0725 corrupt_frac_t_0p0_0p2=1.0000 out_w_norm=10.1581 out_g_norm=3.9652 acc_corrupt_t_0p2_0p4=0.2456 corrupt_frac_t_0p2_0p4=1.0000 acc_corrupt_t_0p4_0p6=0.5545 corrupt_frac_t_0p4_0p6=1.0000 loss_all=3.6109 init_gold_top10=0.2455 init_gold_top100=0.2902
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/hf_xet-1.5.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ uv
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/hf_xet-1.5.0.dist-info/METADATA ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: hf-xet
3
+ Version: 1.5.0
4
+ Classifier: Development Status :: 5 - Production/Stable
5
+ Classifier: License :: OSI Approved :: Apache Software License
6
+ Classifier: Programming Language :: Rust
7
+ Classifier: Programming Language :: Python :: Implementation :: CPython
8
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3 :: Only
11
+ Classifier: Programming Language :: Python :: 3.8
12
+ Classifier: Programming Language :: Python :: 3.9
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Classifier: Programming Language :: Python :: 3.14
18
+ Classifier: Programming Language :: Python :: Free Threading
19
+ Classifier: Programming Language :: Python :: Free Threading :: 2 - Beta
20
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
+ Requires-Dist: pytest ; extra == 'tests'
22
+ Provides-Extra: tests
23
+ License-File: LICENSE
24
+ Summary: Fast transfer of large files with the Hugging Face Hub.
25
+ Maintainer-email: Rajat Arya <rajat@rajatarya.com>, Jared Sulzdorf <j.sulzdorf@gmail.com>, Di Xiao <di@huggingface.co>, Assaf Vayner <assaf@huggingface.co>, Hoyt Koepke <hoytak@gmail.com>
26
+ License-Expression: Apache-2.0
27
+ Requires-Python: >=3.8
28
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
29
+ Project-URL: Documentation, https://huggingface.co/docs/hub/xet/index
30
+ Project-URL: Homepage, https://github.com/huggingface/xet-core
31
+ Project-URL: Issues, https://github.com/huggingface/xet-core/issues
32
+ Project-URL: Repository, https://github.com/huggingface/xet-core.git
33
+
34
+ <!---
35
+ Copyright 2024 The HuggingFace Team. All rights reserved.
36
+
37
+ Licensed under the Apache License, Version 2.0 (the "License");
38
+ you may not use this file except in compliance with the License.
39
+ You may obtain a copy of the License at
40
+
41
+ http://www.apache.org/licenses/LICENSE-2.0
42
+
43
+ Unless required by applicable law or agreed to in writing, software
44
+ distributed under the License is distributed on an "AS IS" BASIS,
45
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
46
+ See the License for the specific language governing permissions and
47
+ limitations under the License.
48
+ -->
49
+ <p align="center">
50
+ <a href="https://github.com/huggingface/xet-core/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/xet-core.svg?color=blue"></a>
51
+ <a href="https://github.com/huggingface/xet-core/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/xet-core.svg"></a>
52
+ <a href="https://github.com/huggingface/xet-core/blob/main/CODE_OF_CONDUCT.md"><img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-v2.0%20adopted-ff69b4.svg"></a>
53
+ </p>
54
+
55
+ <h3 align="center">
56
+ <p>🤗 hf-xet - xet client tech, used in <a target="_blank" href="https://github.com/huggingface/huggingface_hub/">huggingface_hub</a></p>
57
+ </h3>
58
+
59
+ ## Welcome
60
+
61
+ `hf-xet` enables `huggingface_hub` to utilize xet storage for uploading and downloading to HF Hub. Xet storage provides chunk-based deduplication, efficient storage/retrieval with local disk caching, and backwards compatibility with Git LFS. This library is not meant to be used directly, and is instead intended to be used from [huggingface_hub](https://pypi.org/project/huggingface-hub).
62
+
63
+ ## Key features
64
+
65
+ ♻ **chunk-based deduplication implementation**: avoid transferring and storing chunks that are shared across binary files (models, datasets, etc).
66
+
67
+ 🤗 **Python bindings**: bindings for [huggingface_hub](https://github.com/huggingface/huggingface_hub/) package.
68
+
69
+ ↔ **network communications**: concurrent communication to HF Hub Xet backend services (CAS).
70
+
71
+ 🔖 **local disk caching**: chunk-based cache that sits alongside the existing [huggingface_hub disk cache](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
72
+
73
+ ## Installation
74
+
75
+ Install the `hf_xet` package with [pip](https://pypi.org/project/hf-xet/):
76
+
77
+ ```bash
78
+ pip install hf_xet
79
+ ```
80
+
81
+ ## Quick Start
82
+
83
+ `hf_xet` is not intended to be run independently as it is expected to be used from `huggingface_hub`, so to get started with `huggingface_hub` check out the documentation [here](https://hf.co/docs/huggingface_hub).
84
+
85
+ ## Contributions (feature requests, bugs, etc.) are encouraged & appreciated 💙💚💛💜🧡❤️
86
+
87
+ Please join us in making hf-xet better. We value everyone's contributions. Code is not the only way to help. Answering questions, helping each other, improving documentation, filing issues all help immensely. If you are interested in contributing (please do!), check out the [contribution guide](https://github.com/huggingface/xet-core/blob/main/CONTRIBUTING.md) for this repository.
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/hf_xet-1.5.0.dist-info/REQUESTED ADDED
File without changes
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/typer/core.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import errno
2
+ import inspect
3
+ import os
4
+ import sys
5
+ from collections.abc import Callable, MutableMapping, Sequence
6
+ from difflib import get_close_matches
7
+ from enum import Enum
8
+ from gettext import gettext as _
9
+ from typing import (
10
+ Any,
11
+ TextIO,
12
+ Union,
13
+ cast,
14
+ )
15
+
16
+ import click
17
+ import click.core
18
+ import click.formatting
19
+ import click.shell_completion
20
+ import click.types
21
+ import click.utils
22
+
23
+ from ._typing import Literal
24
+ from .utils import parse_boolean_env_var
25
+
26
# Accepted markup modes for help rendering; None means no markup mode is set.
MarkupMode = Literal["markdown", "rich", None]
# Key used in ``ctx.obj`` (when it is a dict) to carry a command's markup mode.
MARKUP_MODE_KEY = "TYPER_RICH_MARKUP_MODE"

# Read the TYPER_USE_RICH environment variable through parse_boolean_env_var,
# falling back to True when that helper applies its default.
HAS_RICH = parse_boolean_env_var(os.getenv("TYPER_USE_RICH"), default=True)

# Default markup mode: "rich" when HAS_RICH is truthy, otherwise None.
if HAS_RICH:
    DEFAULT_MARKUP_MODE: MarkupMode = "rich"
else:
    DEFAULT_MARKUP_MODE = None
35
+
36
+
37
+ # Copy from click.parser._split_opt
38
+ def _split_opt(opt: str) -> tuple[str, str]:
39
+ first = opt[:1]
40
+ if first.isalnum():
41
+ return "", opt
42
+ if opt[1:2] == first:
43
+ return opt[:2], opt[2:]
44
+ return first, opt[1:]
45
+
46
+
47
+ def _typer_param_setup_autocompletion_compat(
48
+ self: click.Parameter,
49
+ *,
50
+ autocompletion: Callable[
51
+ [click.Context, list[str], str], list[tuple[str, str] | str]
52
+ ]
53
+ | None = None,
54
+ ) -> None:
55
+ if self._custom_shell_complete is not None:
56
+ import warnings
57
+
58
+ warnings.warn(
59
+ "In Typer, only the parameter 'autocompletion' is supported. "
60
+ "The support for 'shell_complete' is deprecated and will be removed in upcoming versions. ",
61
+ DeprecationWarning,
62
+ stacklevel=2,
63
+ )
64
+
65
+ if autocompletion is not None:
66
+
67
+ def compat_autocompletion(
68
+ ctx: click.Context, param: click.core.Parameter, incomplete: str
69
+ ) -> list["click.shell_completion.CompletionItem"]:
70
+ from click.shell_completion import CompletionItem
71
+
72
+ out = []
73
+
74
+ for c in autocompletion(ctx, [], incomplete):
75
+ if isinstance(c, tuple):
76
+ use_completion = CompletionItem(c[0], help=c[1])
77
+ else:
78
+ assert isinstance(c, str)
79
+ use_completion = CompletionItem(c)
80
+
81
+ if use_completion.value.startswith(incomplete):
82
+ out.append(use_completion)
83
+
84
+ return out
85
+
86
+ self._custom_shell_complete = compat_autocompletion
87
+
88
+
89
def _get_default_string(
    obj: Union["TyperArgument", "TyperOption"],
    *,
    ctx: click.Context,
    show_default_is_str: bool,
    default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any,
) -> str:
    """Render the text shown for a parameter's default value in help output.

    The branches are checked in order:

    * ``show_default_is_str``: ``obj.show_default`` wrapped in parentheses.
    * list/tuple default: each element rendered recursively, joined by ", ".
    * Enum default: ``str`` of the member's ``value``.
    * plain function default: the translated "(dynamic)" marker.
    * boolean-flag option with secondary opts: the option name without its
      prefix (primary or secondary depending on ``obj.default``).
    * boolean-flag option without secondary opts and a falsy default: "".
    * anything else: ``str(default_value)``.
    """
    # Extracted from click.core.Option.get_help_record() to be reused by
    # rich_utils avoiding RegEx hacks
    if show_default_is_str:
        default_string = f"({obj.show_default})"
    elif isinstance(default_value, (list, tuple)):
        default_string = ", ".join(
            _get_default_string(
                obj, ctx=ctx, show_default_is_str=show_default_is_str, default_value=d
            )
            for d in default_value
        )
    elif isinstance(default_value, Enum):
        default_string = str(default_value.value)
    elif inspect.isfunction(default_value):
        default_string = _("(dynamic)")
    elif isinstance(obj, TyperOption) and obj.is_bool_flag and obj.secondary_opts:
        # For boolean flags that have distinct True/False opts,
        # use the opt without prefix instead of the value.
        # Typer override, original commented
        # default_string = click.parser.split_opt(
        #     (self.opts if self.default else self.secondary_opts)[0]
        # )[1]
        if obj.default:
            if obj.opts:
                default_string = _split_opt(obj.opts[0])[1]
            else:
                # No primary opts to name, so fall back to the raw value.
                default_string = str(default_value)
        else:
            default_string = _split_opt(obj.secondary_opts[0])[1]
        # Typer override end
    elif (
        isinstance(obj, TyperOption)
        and obj.is_bool_flag
        and not obj.secondary_opts
        and not default_value
    ):
        default_string = ""
    else:
        default_string = str(default_value)
    return default_string
136
+
137
+
138
+ def _extract_default_help_str(
139
+ obj: Union["TyperArgument", "TyperOption"], *, ctx: click.Context
140
+ ) -> Any | Callable[[], Any] | None:
141
+ # Extracted from click.core.Option.get_help_record() to be reused by
142
+ # rich_utils avoiding RegEx hacks
143
+ # Temporarily enable resilient parsing to avoid type casting
144
+ # failing for the default. Might be possible to extend this to
145
+ # help formatting in general.
146
+ resilient = ctx.resilient_parsing
147
+ ctx.resilient_parsing = True
148
+
149
+ try:
150
+ default_value = obj.get_default(ctx, call=False)
151
+ finally:
152
+ ctx.resilient_parsing = resilient
153
+ return default_value
154
+
155
+
156
def _main(
    self: click.Command,
    *,
    args: Sequence[str] | None = None,
    prog_name: str | None = None,
    complete_var: str | None = None,
    standalone_mode: bool = True,
    windows_expand_args: bool = True,
    rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE,
    **extra: Any,
) -> Any:
    """Run ``self`` as a command-line entry point.

    Shared implementation behind ``TyperCommand.main`` and ``TyperGroup.main``.

    Args:
        args: Arguments to parse; ``sys.argv[1:]`` when None.
        prog_name: Program name; detected through click when None.
        complete_var: Forwarded to ``self._main_shell_completion``.
        standalone_mode: When True, errors are printed and the process exits
            via ``sys.exit``; when False, the invoke result (or an explicit
            exit code) is returned and Click exceptions / aborts propagate.
        windows_expand_args: On Windows (``os.name == "nt"``), expand args
            through ``click.utils._expand_args``.
        rich_markup_mode: When not None and HAS_RICH is truthy, errors and
            aborts are rendered through ``rich_utils``.
        **extra: Forwarded to ``self.make_context``.
    """
    # Typer override, duplicated from click.main() to handle custom rich exceptions
    # Verify that the environment is configured correctly, or reject
    # further execution to avoid a broken script.
    if args is None:
        args = sys.argv[1:]

        # Covered in Click tests
        if os.name == "nt" and windows_expand_args:  # pragma: no cover
            args = click.utils._expand_args(args)
    else:
        args = list(args)

    if prog_name is None:
        prog_name = click.utils._detect_program_name()

    # Process shell completion requests and exit early.
    self._main_shell_completion(extra, prog_name, complete_var)

    try:
        try:
            with self.make_context(prog_name, args, **extra) as ctx:
                rv = self.invoke(ctx)
                if not standalone_mode:
                    return rv
                # it's not safe to `ctx.exit(rv)` here!
                # note that `rv` may actually contain data like "1" which
                # has obvious effects
                # more subtle case: `rv=[None, None]` can come out of
                # chained commands which all returned `None` -- so it's not
                # even always obvious that `rv` indicates success/failure
                # by its truthiness/falsiness
                ctx.exit()
        except EOFError as e:
            # Emit a newline on stderr, then convert to an Abort handled below.
            click.echo(file=sys.stderr)
            raise click.Abort() from e
        except KeyboardInterrupt as e:
            # Converted to an Exit with code 130, handled by the outer handler.
            raise click.exceptions.Exit(130) from e
        except click.ClickException as e:
            if not standalone_mode:
                raise
            # Typer override
            if HAS_RICH and rich_markup_mode is not None:
                from . import rich_utils

                rich_utils.rich_format_error(e)
            else:
                e.show()
            # Typer override end
            sys.exit(e.exit_code)
        except OSError as e:
            if e.errno == errno.EPIPE:
                # Broken pipe: wrap both streams in PacifyFlushWrapper
                # before exiting with status 1.
                sys.stdout = cast(TextIO, click.utils.PacifyFlushWrapper(sys.stdout))
                sys.stderr = cast(TextIO, click.utils.PacifyFlushWrapper(sys.stderr))
                sys.exit(1)
            else:
                raise
    except click.exceptions.Exit as e:
        if standalone_mode:
            sys.exit(e.exit_code)
        else:
            # in non-standalone mode, return the exit code
            # note that this is only reached if `self.invoke` above raises
            # an Exit explicitly -- thus bypassing the check there which
            # would return its result
            # the results of non-standalone execution may therefore be
            # somewhat ambiguous: if there are codepaths which lead to
            # `ctx.exit(1)` and to `return 1`, the caller won't be able to
            # tell the difference between the two
            return e.exit_code
    except click.Abort:
        if not standalone_mode:
            raise
        # Typer override
        if HAS_RICH and rich_markup_mode is not None:
            from . import rich_utils

            rich_utils.rich_abort_error()
        else:
            click.echo(_("Aborted!"), file=sys.stderr)
        # Typer override end
        sys.exit(1)
248
+
249
+
250
class TyperArgument(click.core.Argument):
    """click Argument that also carries help metadata and can emit a help record.

    On top of the click constructor arguments it stores ``help``,
    ``show_default``, ``show_choices``, ``show_envvar``, ``hidden`` and
    ``rich_help_panel`` on the instance.
    """

    def __init__(
        self,
        *,
        # Parameter
        param_decls: list[str],
        type: Any | None = None,
        required: bool | None = None,
        default: Any | None = None,
        callback: Callable[..., Any] | None = None,
        nargs: int | None = None,
        metavar: str | None = None,
        expose_value: bool = True,
        is_eager: bool = False,
        envvar: str | list[str] | None = None,
        # Note that shell_complete is not fully supported and will be removed in future versions
        # TODO: Remove shell_complete in a future version (after 0.16.0)
        shell_complete: Callable[
            [click.Context, click.Parameter, str],
            list["click.shell_completion.CompletionItem"] | list[str],
        ]
        | None = None,
        autocompletion: Callable[..., Any] | None = None,
        # TyperArgument
        show_default: bool | str = True,
        show_choices: bool = True,
        show_envvar: bool = True,
        help: str | None = None,
        hidden: bool = False,
        # Rich settings
        rich_help_panel: str | None = None,
    ):
        # The Typer-specific attributes are assigned before the click
        # constructor runs.
        self.help = help
        self.show_default = show_default
        self.show_choices = show_choices
        self.show_envvar = show_envvar
        self.hidden = hidden
        self.rich_help_panel = rich_help_panel

        super().__init__(
            param_decls=param_decls,
            type=type,
            required=required,
            default=default,
            callback=callback,
            nargs=nargs,
            metavar=metavar,
            expose_value=expose_value,
            is_eager=is_eager,
            envvar=envvar,
            shell_complete=shell_complete,
        )
        _typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion)

    def _get_default_string(
        self,
        *,
        ctx: click.Context,
        show_default_is_str: bool,
        default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any,
    ) -> str:
        """Delegate to the module-level ``_get_default_string`` for this argument."""
        return _get_default_string(
            self,
            ctx=ctx,
            show_default_is_str=show_default_is_str,
            default_value=default_value,
        )

    def _extract_default_help_str(
        self, *, ctx: click.Context
    ) -> Any | Callable[[], Any] | None:
        """Delegate to the module-level ``_extract_default_help_str``."""
        return _extract_default_help_str(self, ctx=ctx)

    def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None:
        """Return ``(metavar, help text)`` for this argument, or None if hidden.

        The help text is ``self.help`` followed by a bracketed, "; "-joined
        list of extras: env var, default and "required", when applicable.
        """
        # Modified version of click.core.Option.get_help_record()
        # to support Arguments
        if self.hidden:
            return None
        name = self.make_metavar(ctx=ctx)
        help = self.help or ""
        extra = []
        if self.show_envvar:
            envvar = self.envvar
            # allow_from_autoenv is currently not supported in Typer for CLI Arguments
            if envvar is not None:
                var_str = (
                    ", ".join(str(d) for d in envvar)
                    if isinstance(envvar, (list, tuple))
                    else envvar
                )
                extra.append(f"env var: {var_str}")

        # Typer override:
        # Extracted to _extract_default_help_str() to allow re-using it in rich_utils
        default_value = self._extract_default_help_str(ctx=ctx)
        # Typer override end

        show_default_is_str = isinstance(self.show_default, str)

        # A string show_default is always shown; otherwise a default is shown
        # only when one exists and showing is enabled on the param or context.
        if show_default_is_str or (
            default_value is not None and (self.show_default or ctx.show_default)
        ):
            # Typer override:
            # Extracted to _get_default_string() to allow re-using it in rich_utils
            default_string = self._get_default_string(
                ctx=ctx,
                show_default_is_str=show_default_is_str,
                default_value=default_value,
            )
            # Typer override end
            if default_string:
                extra.append(_("default: {default}").format(default=default_string))
        if self.required:
            extra.append(_("required"))
        if extra:
            extra_str = "; ".join(extra)
            extra_str = f"[{extra_str}]"
            rich_markup_mode = None
            if hasattr(ctx, "obj") and isinstance(ctx.obj, dict):
                rich_markup_mode = ctx.obj.get(MARKUP_MODE_KEY, None)
            if HAS_RICH and rich_markup_mode == "rich":
                # This is needed for when we want to export to HTML
                from . import rich_utils

                extra_str = rich_utils.escape_before_html_export(extra_str)

            help = f"{help}  {extra_str}" if help else f"{extra_str}"
        return name, help

    def make_metavar(self, ctx: click.Context) -> str:
        """Build the metavar shown in usage/help for this argument.

        An explicit ``self.metavar`` wins (bracketed when the argument is not
        required and it is not already bracketed). Otherwise the upper-cased
        name is used, bracketed when not required, followed by ``:TYPE`` when
        the type supplies a metavar and ``...`` when ``nargs != 1``.
        """
        # Modified version of click.core.Argument.make_metavar()
        # to include Argument name
        if self.metavar is not None:
            var = self.metavar
            if not self.required and not var.startswith("["):
                var = f"[{var}]"
            return var
        var = (self.name or "").upper()
        if not self.required:
            var = f"[{var}]"
        type_var = self.type.get_metavar(self, ctx=ctx)
        if type_var:
            var += f":{type_var}"
        if self.nargs != 1:
            var += "..."
        return var

    def value_is_missing(self, value: Any) -> bool:
        """Delegate to the module-level ``_value_is_missing``."""
        return _value_is_missing(self, value)
399
+
400
+
401
class TyperOption(click.core.Option):
    """click Option with Typer's autocompletion shim and custom help record.

    Stores ``rich_help_panel`` on the instance in addition to everything the
    click constructor sets up.
    """

    def __init__(
        self,
        *,
        # Parameter
        param_decls: list[str],
        type: click.types.ParamType | Any | None = None,
        required: bool | None = None,
        default: Any | None = None,
        callback: Callable[..., Any] | None = None,
        nargs: int | None = None,
        metavar: str | None = None,
        expose_value: bool = True,
        is_eager: bool = False,
        envvar: str | list[str] | None = None,
        # Note that shell_complete is not fully supported and will be removed in future versions
        # TODO: Remove shell_complete in a future version (after 0.16.0)
        shell_complete: Callable[
            [click.Context, click.Parameter, str],
            list["click.shell_completion.CompletionItem"] | list[str],
        ]
        | None = None,
        autocompletion: Callable[..., Any] | None = None,
        # Option
        show_default: bool | str = False,
        prompt: bool | str = False,
        confirmation_prompt: bool | str = False,
        prompt_required: bool = True,
        hide_input: bool = False,
        is_flag: bool | None = None,
        multiple: bool = False,
        count: bool = False,
        allow_from_autoenv: bool = True,
        help: str | None = None,
        hidden: bool = False,
        show_choices: bool = True,
        show_envvar: bool = False,
        # Rich settings
        rich_help_panel: str | None = None,
    ):
        super().__init__(
            param_decls=param_decls,
            type=type,
            required=required,
            default=default,
            callback=callback,
            nargs=nargs,
            metavar=metavar,
            expose_value=expose_value,
            is_eager=is_eager,
            envvar=envvar,
            show_default=show_default,
            prompt=prompt,
            confirmation_prompt=confirmation_prompt,
            hide_input=hide_input,
            is_flag=is_flag,
            multiple=multiple,
            count=count,
            allow_from_autoenv=allow_from_autoenv,
            help=help,
            hidden=hidden,
            show_choices=show_choices,
            show_envvar=show_envvar,
            prompt_required=prompt_required,
            shell_complete=shell_complete,
        )
        _typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion)
        self.rich_help_panel = rich_help_panel

    def _get_default_string(
        self,
        *,
        ctx: click.Context,
        show_default_is_str: bool,
        default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any,
    ) -> str:
        """Delegate to the module-level ``_get_default_string`` for this option."""
        return _get_default_string(
            self,
            ctx=ctx,
            show_default_is_str=show_default_is_str,
            default_value=default_value,
        )

    def _extract_default_help_str(
        self, *, ctx: click.Context
    ) -> Any | Callable[[], Any] | None:
        """Delegate to the module-level ``_extract_default_help_str``."""
        return _extract_default_help_str(self, ctx=ctx)

    def make_metavar(self, ctx: click.Context) -> str:
        """Return the parent class's metavar, passing ``ctx`` by keyword."""
        return super().make_metavar(ctx=ctx)

    def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None:
        """Return ``(option names, help text)`` for this option, or None if hidden.

        The help text is ``self.help`` followed by a bracketed, "; "-joined
        list of extras: env var, default, numeric range and "required", when
        applicable.
        """
        # Duplicate all of Click's logic only to modify a single line, to allow boolean
        # flags with only names for False values as it's currently supported by Typer
        # Ref: https://typer.tiangolo.com/tutorial/parameter-types/bool/#only-names-for-false
        if self.hidden:
            return None

        any_prefix_is_slash = False

        def _write_opts(opts: Sequence[str]) -> str:
            # Joins the option names and records whether any of them uses a
            # slash prefix, which changes the separator chosen at the end.
            nonlocal any_prefix_is_slash

            rv, any_slashes = click.formatting.join_options(opts)

            if any_slashes:
                any_prefix_is_slash = True

            if not self.is_flag and not self.count:
                rv += f" {self.make_metavar(ctx=ctx)}"

            return rv

        rv = [_write_opts(self.opts)]

        if self.secondary_opts:
            rv.append(_write_opts(self.secondary_opts))

        help = self.help or ""
        extra = []

        if self.show_envvar:
            envvar = self.envvar

            if envvar is None:
                # No explicit env var: derive one from the context's auto
                # prefix and the parameter name, when allowed.
                if (
                    self.allow_from_autoenv
                    and ctx.auto_envvar_prefix is not None
                    and self.name is not None
                ):
                    envvar = f"{ctx.auto_envvar_prefix}_{self.name.upper()}"

            if envvar is not None:
                var_str = (
                    envvar
                    if isinstance(envvar, str)
                    else ", ".join(str(d) for d in envvar)
                )
                extra.append(_("env var: {var}").format(var=var_str))

        # Typer override:
        # Extracted to _extract_default_help_str() to allow re-using it in rich_utils
        default_value = self._extract_default_help_str(ctx=ctx)
        # Typer override end

        show_default_is_str = isinstance(self.show_default, str)

        # A string show_default is always shown; otherwise a default is shown
        # only when one exists and showing is enabled on the param or context.
        if show_default_is_str or (
            default_value is not None and (self.show_default or ctx.show_default)
        ):
            # Typer override:
            # Extracted to _get_default_string() to allow re-using it in rich_utils
            default_string = self._get_default_string(
                ctx=ctx,
                show_default_is_str=show_default_is_str,
                default_value=default_value,
            )
            # Typer override end
            if default_string:
                extra.append(_("default: {default}").format(default=default_string))

        if isinstance(self.type, click.types._NumberRangeBase):
            range_str = self.type._describe_range()

            if range_str:
                extra.append(range_str)

        if self.required:
            extra.append(_("required"))

        if extra:
            extra_str = "; ".join(extra)
            extra_str = f"[{extra_str}]"
            rich_markup_mode = None
            if hasattr(ctx, "obj") and isinstance(ctx.obj, dict):
                rich_markup_mode = ctx.obj.get(MARKUP_MODE_KEY, None)
            if HAS_RICH and rich_markup_mode == "rich":
                # This is needed for when we want to export to HTML
                from . import rich_utils

                extra_str = rich_utils.escape_before_html_export(extra_str)

            help = f"{help}  {extra_str}" if help else f"{extra_str}"

        # Primary and secondary names are joined with "; " when any name uses
        # a slash prefix, and with " / " otherwise.
        return ("; " if any_prefix_is_slash else " / ").join(rv), help

    def value_is_missing(self, value: Any) -> bool:
        """Delegate to the module-level ``_value_is_missing``."""
        return _value_is_missing(self, value)
589
+
590
+
591
+ def _value_is_missing(param: click.Parameter, value: Any) -> bool:
592
+ if value is None:
593
+ return True
594
+
595
+ # Click 8.3 and beyond
596
+ # if value is UNSET:
597
+ # return True
598
+
599
+ if (param.nargs != 1 or param.multiple) and value == ():
600
+ return True # pragma: no cover
601
+
602
+ return False
603
+
604
+
605
+ def _typer_format_options(
606
+ self: click.core.Command, *, ctx: click.Context, formatter: click.HelpFormatter
607
+ ) -> None:
608
+ args = []
609
+ opts = []
610
+ for param in self.get_params(ctx):
611
+ rv = param.get_help_record(ctx)
612
+ if rv is not None:
613
+ if param.param_type_name == "argument":
614
+ args.append(rv)
615
+ elif param.param_type_name == "option":
616
+ opts.append(rv)
617
+
618
+ if args:
619
+ with formatter.section(_("Arguments")):
620
+ formatter.write_dl(args)
621
+ if opts:
622
+ with formatter.section(_("Options")):
623
+ formatter.write_dl(opts)
624
+
625
+
626
+ def _typer_main_shell_completion(
627
+ self: click.core.Command,
628
+ *,
629
+ ctx_args: MutableMapping[str, Any],
630
+ prog_name: str,
631
+ complete_var: str | None = None,
632
+ ) -> None:
633
+ if complete_var is None:
634
+ complete_var = f"_{prog_name}_COMPLETE".replace("-", "_").upper()
635
+
636
+ instruction = os.environ.get(complete_var)
637
+
638
+ if not instruction:
639
+ return
640
+
641
+ from .completion import shell_complete
642
+
643
+ rv = shell_complete(self, ctx_args, prog_name, complete_var, instruction)
644
+ sys.exit(rv)
645
+
646
+
647
class TyperCommand(click.core.Command):
    """click Command that routes main/help/completion through Typer's helpers."""

    def __init__(
        self,
        name: str | None,
        *,
        context_settings: dict[str, Any] | None = None,
        callback: Callable[..., Any] | None = None,
        params: list[click.Parameter] | None = None,
        help: str | None = None,
        epilog: str | None = None,
        short_help: str | None = None,
        options_metavar: str | None = "[OPTIONS]",
        add_help_option: bool = True,
        no_args_is_help: bool = False,
        hidden: bool = False,
        deprecated: bool = False,
        # Rich settings
        rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE,
        rich_help_panel: str | None = None,
    ) -> None:
        """Forward the click arguments and remember the Rich settings."""
        click_kwargs: dict[str, Any] = dict(
            name=name,
            context_settings=context_settings,
            callback=callback,
            params=params,
            help=help,
            epilog=epilog,
            short_help=short_help,
            options_metavar=options_metavar,
            add_help_option=add_help_option,
            no_args_is_help=no_args_is_help,
            hidden=hidden,
            deprecated=deprecated,
        )
        super().__init__(**click_kwargs)
        self.rich_markup_mode: MarkupMode = rich_markup_mode
        self.rich_help_panel = rich_help_panel

    def format_options(
        self, ctx: click.Context, formatter: click.HelpFormatter
    ) -> None:
        """Write the Arguments/Options sections via ``_typer_format_options``."""
        _typer_format_options(self, ctx=ctx, formatter=formatter)

    def _main_shell_completion(
        self,
        ctx_args: MutableMapping[str, Any],
        prog_name: str,
        complete_var: str | None = None,
    ) -> None:
        """Handle completion requests via ``_typer_main_shell_completion``."""
        _typer_main_shell_completion(
            self, ctx_args=ctx_args, prog_name=prog_name, complete_var=complete_var
        )

    def main(
        self,
        args: Sequence[str] | None = None,
        prog_name: str | None = None,
        complete_var: str | None = None,
        standalone_mode: bool = True,
        windows_expand_args: bool = True,
        **extra: Any,
    ) -> Any:
        """Run the command through ``_main`` using this command's markup mode."""
        return _main(
            self,
            args=args,
            prog_name=prog_name,
            complete_var=complete_var,
            standalone_mode=standalone_mode,
            windows_expand_args=windows_expand_args,
            rich_markup_mode=self.rich_markup_mode,
            **extra,
        )

    def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """Render help with Rich when enabled, else with click's formatter."""
        if HAS_RICH and self.rich_markup_mode is not None:
            from . import rich_utils

            return rich_utils.rich_format_help(
                obj=self,
                ctx=ctx,
                markup_mode=self.rich_markup_mode,
            )

        # Plain path: make sure ctx.obj exists, then record the markup mode in
        # it (only when it is a dict) before handing over to click.
        if not hasattr(ctx, "obj") or ctx.obj is None:
            ctx.ensure_object(dict)
        if isinstance(ctx.obj, dict):
            ctx.obj[MARKUP_MODE_KEY] = self.rich_markup_mode
        return super().format_help(ctx, formatter)
733
+
734
+
735
class TyperGroup(click.core.Group):
    """click Group with Typer's help rendering and "did you mean" hints."""

    def __init__(
        self,
        *,
        name: str | None = None,
        commands: dict[str, click.Command] | Sequence[click.Command] | None = None,
        # Rich settings
        rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE,
        rich_help_panel: str | None = None,
        suggest_commands: bool = True,
        **attrs: Any,
    ) -> None:
        """Forward to click's Group and remember the Typer-specific settings."""
        super().__init__(name=name, commands=commands, **attrs)
        self.rich_markup_mode: MarkupMode = rich_markup_mode
        self.rich_help_panel = rich_help_panel
        self.suggest_commands = suggest_commands

    def format_options(
        self, ctx: click.Context, formatter: click.HelpFormatter
    ) -> None:
        """Write the Arguments/Options sections, then the commands listing."""
        _typer_format_options(self, ctx=ctx, formatter=formatter)
        self.format_commands(ctx, formatter)

    def _main_shell_completion(
        self,
        ctx_args: MutableMapping[str, Any],
        prog_name: str,
        complete_var: str | None = None,
    ) -> None:
        """Handle completion requests via ``_typer_main_shell_completion``."""
        _typer_main_shell_completion(
            self, ctx_args=ctx_args, prog_name=prog_name, complete_var=complete_var
        )

    def resolve_command(
        self, ctx: click.Context, args: list[str]
    ) -> tuple[str | None, click.Command | None, list[str]]:
        """Resolve a subcommand, adding close-match hints to usage errors.

        The UsageError raised by click is always re-raised; when suggestions
        are enabled and close matches for ``args[0]`` exist among the
        registered command names, its message is extended first.
        """
        try:
            return super().resolve_command(ctx, args)
        except click.UsageError as err:
            known = list(self.commands.keys()) if self.suggest_commands else []
            if known and args:
                close = get_close_matches(args[0], known)
                if close:
                    hint = ", ".join(repr(match) for match in close)
                    base = err.message.rstrip(".")
                    err.message = f"{base}. Did you mean {hint}?"
            raise

    def main(
        self,
        args: Sequence[str] | None = None,
        prog_name: str | None = None,
        complete_var: str | None = None,
        standalone_mode: bool = True,
        windows_expand_args: bool = True,
        **extra: Any,
    ) -> Any:
        """Run the group through ``_main`` using this group's markup mode."""
        return _main(
            self,
            args=args,
            prog_name=prog_name,
            complete_var=complete_var,
            standalone_mode=standalone_mode,
            windows_expand_args=windows_expand_args,
            rich_markup_mode=self.rich_markup_mode,
            **extra,
        )

    def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """Render help with Rich when enabled, else with click's formatter."""
        if HAS_RICH and self.rich_markup_mode is not None:
            from . import rich_utils

            return rich_utils.rich_format_help(
                obj=self,
                ctx=ctx,
                markup_mode=self.rich_markup_mode,
            )
        return super().format_help(ctx, formatter)

    def list_commands(self, ctx: click.Context) -> list[str]:
        """Returns a list of subcommand names.
        Note that in Click's Group class, these are sorted.
        In Typer, we wish to maintain the original order of creation (cf Issue #933)"""
        return list(self.commands)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/typer/py.typed ADDED
File without changes
LTA_openwebtext_dualt/scripts/dirichlet_support_decode_probe.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import math
7
+ import re
8
+ import sys
9
+ from collections import Counter
10
+ from pathlib import Path
11
+ from typing import Sequence
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
# REPO_ROOT is the parent of the directory that contains this (resolved) file.
# It is prepended to sys.path once; presumably so the project imports below
# (eval, flowtext_lab) resolve when the script is run directly -- TODO confirm.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
19
+
20
+ from eval import build_model_from_ckpt
21
+ from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
22
+ from flowtext_lab.tokenization import BpeTextTokenizer
23
+
24
+
25
# Splits text into pieces: a run of ASCII letters, a run of digits (\d), or
# any single character that is neither whitespace, an ASCII letter nor a digit.
WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]")
26
+
27
+
28
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Encode ``prompt`` without special tokens and keep the first ``max_len`` ids."""
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    token_ids = list(encoding.ids)
    return token_ids[:max_len]
30
+
31
+
32
def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str:
    """Decode ``ids`` in full: no stop at EOS and special tokens kept."""
    decode_options = {"stop_at_eos": False, "skip_special_tokens": False}
    return tokenizer.decode(ids, **decode_options)
34
+
35
+
36
+ def text_metrics(text: str) -> dict[str, float]:
37
+ toks = WORD_RE.findall(text)
38
+ words = [t.lower() for t in toks if re.fullmatch(r"[A-Za-z]+", t)]
39
+ n_tok = max(len(toks), 1)
40
+ n_words = max(len(words), 1)
41
+ wc = Counter(words)
42
+ max_word_frac = wc.most_common(1)[0][1] / n_words if wc else 1.0
43
+ grams3 = list(zip(toks, toks[1:], toks[2:]))
44
+ rep3 = sum(v - 1 for v in Counter(grams3).values() if v > 1) / max(len(grams3), 1)
45
+ bigrams = list(zip(words, words[1:]))
46
+ distinct2 = len(set(bigrams)) / max(len(bigrams), 1) if bigrams else 0.0
47
+ punct_frac = sum(bool(re.fullmatch(r"[,.;:!?]+", t)) for t in toks) / n_tok
48
+ digit_frac = sum(t.isdigit() for t in toks) / n_tok
49
+ quality = (
50
+ min(len(text) / 700.0, 1.0)
51
+ + 0.35 * distinct2
52
+ - 2.6 * rep3
53
+ - 1.2 * max_word_frac
54
+ - 0.8 * punct_frac
55
+ - 1.0 * digit_frac
56
+ - 0.2 * text.count("<|endoftext|>")
57
+ - 0.5 * text.count("�")
58
+ )
59
+ return {
60
+ "quality": float(quality),
61
+ "chars": float(len(text)),
62
+ "words": float(len(words)),
63
+ "rep3": float(rep3),
64
+ "distinct2": float(distinct2),
65
+ "punct_frac": float(punct_frac),
66
+ "max_word_frac": float(max_word_frac),
67
+ "eot_count": float(text.count("<|endoftext|>")),
68
+ }
69
+
70
+
71
def dirichlet_mean(endpoint: torch.Tensor, support_t: float, eps: float) -> torch.Tensor:
    """Blend ``endpoint`` with the uniform distribution over its last axis.

    Computes ``(1 - support_t) / V + support_t * endpoint`` with ``V`` the size
    of the last dimension, floors the result at ``eps`` and renormalises it to
    sum to one along the last dimension.
    """
    vocab_size = endpoint.size(-1)
    uniform_mass = (1.0 - support_t) / float(vocab_size)
    blended = (uniform_mass + support_t * endpoint).clamp_min(eps)
    total = blended.sum(dim=-1, keepdim=True).clamp_min(eps)
    return blended / total
76
+
77
+
78
def total_concentration(support_t: float, c_min: float, c_max: float) -> float:
    """Interpolate log-linearly between two concentration values.

    Returns ``exp(log(lo) + support_t * (log(hi) - log(lo)))`` where ``lo`` is
    ``c_min`` floored at 1e-8 and ``hi`` is ``c_max`` raised to at least ``lo``.

    Args:
        support_t: Interpolation weight; 0 yields ``lo`` and 1 yields ``hi``.
        c_min: Lower concentration; values below 1e-8 are floored.
        c_max: Upper concentration; values below ``lo`` are replaced by ``lo``.

    Returns:
        The interpolated concentration as a float.
    """
    lo = max(c_min, 1e-8)
    # Bound the upper value by the *floored* lower value. Comparing against
    # the raw c_min let non-positive (c_min, c_max) pairs reach math.log and
    # raise "math domain error", and let a tiny positive c_max end up below
    # the floored lower bound.
    hi = max(c_max, lo)
    log_lo = math.log(lo)
    log_hi = math.log(hi)
    return math.exp(log_lo + support_t * (log_hi - log_lo))
82
+
83
+
84
def dirichlet_resample(mean: torch.Tensor, support_t: float, c_min: float, c_max: float, eps: float) -> torch.Tensor:
    """Draw a random simplex point around ``mean``.

    The Gamma shape parameters are ``mean`` scaled by
    ``total_concentration(support_t, c_min, c_max)`` and floored at ``eps``.
    The Gamma draws are floored at ``eps`` and normalised along the last
    dimension.
    """
    concentration = total_concentration(support_t, c_min, c_max)
    shape_params = (mean * concentration).clamp_min(eps)
    draws = torch._standard_gamma(shape_params).clamp_min(eps)
    normaliser = draws.sum(dim=-1, keepdim=True).clamp_min(eps)
    return draws / normaliser
89
+
90
+
91
def schedule_power(step: int, steps: int, power: float) -> float:
    """Return ``((step + 1) / max(steps, 1)) ** power`` limited to [0, 1]."""
    progress = (step + 1) / max(steps, 1)
    capped = min(1.0, progress ** float(power))
    return float(max(0.0, capped))
94
+
95
+
96
def current_anchor(probs: torch.Tensor, mode: str, eps: float) -> torch.Tensor:
    """Derive an anchor distribution from ``probs`` according to ``mode``.

    * ``"state"``: ``probs`` itself, returned as-is.
    * ``"onehot"``: one-hot of the argmax along the last dimension, in the
      dtype and on the device of ``probs``.
    * ``"sqrt_state"``: square root of ``probs`` (floored at ``eps``),
      renormalised along the last dimension.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "state":
        return probs
    if mode == "sqrt_state":
        rooted = probs.clamp_min(eps).sqrt()
        norm = rooted.sum(dim=-1, keepdim=True).clamp_min(eps)
        return rooted / norm
    if mode != "onehot":
        raise ValueError(f"unknown anchor mode: {mode}")
    winners = probs.argmax(dim=-1)
    hard = F.one_hot(winners, probs.size(-1))
    return hard.to(dtype=probs.dtype, device=probs.device)
106
+
107
+
108
@torch.no_grad()
def build_initial(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    eps: float,
    noise_init: str,
    target_prob: float,
    noise_sigma: float,
    dirichlet_concentration: float,
    lock_bos: bool,
    lock_final_eos: bool,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Build the starting state for decoding a batch of prompts.

    Every prompt is repeated ``restarts`` times. Its token ids (optionally
    prefixed with the BOS id, then cut to ``max_len``) are written as one-hot
    rows over the leading positions of a noise tensor obtained from
    ``sample_noise_simplex``, and those positions are marked as locked. With
    ``lock_final_eos`` the last position of every row is set to the EOS
    one-hot and locked as well.

    Returns:
        ``(probs, lock, lock_probs, attn, expanded)``: the initial tensor, the
        boolean lock mask, the one-hot values at locked positions (zeros
        elsewhere), an all-True boolean mask of shape ``(batch, max_len)``,
        and the prompt string for each batch row.
    """
    expanded: list[str] = []
    prompt_ids: list[list[int]] = []
    for prompt in prompts:
        token_ids = encode_prompt(tokenizer, prompt, max_len)
        if lock_bos:
            token_ids = [tokenizer.bos_id] + token_ids
        # Re-trim: prepending BOS may have pushed the prompt past max_len.
        token_ids = token_ids[:max_len]
        expanded.extend([prompt] * restarts)
        prompt_ids.extend([token_ids] * restarts)

    batch = len(prompt_ids)
    grid = (batch, max_len)
    probs = sample_noise_simplex(
        grid,
        tokenizer.vocab_size,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=noise_sigma,
        dirichlet_concentration=dirichlet_concentration,
    )
    lock = torch.zeros(grid, dtype=torch.bool, device=device)
    lock_probs = torch.zeros((batch, max_len, tokenizer.vocab_size), dtype=torch.float32, device=device)

    for row, token_ids in enumerate(prompt_ids):
        width = len(token_ids)
        if width == 0:
            continue
        index = torch.tensor(token_ids, dtype=torch.long, device=device)
        hard = F.one_hot(index, tokenizer.vocab_size).float()
        probs[row, :width] = hard
        lock_probs[row, :width] = hard
        lock[row, :width] = True

    if lock_final_eos:
        eos_index = torch.tensor([tokenizer.eos_id], dtype=torch.long, device=device)
        eos_row = F.one_hot(eos_index, tokenizer.vocab_size).float()[0]
        probs[:, -1] = eos_row
        lock_probs[:, -1] = eos_row
        lock[:, -1] = True

    attn = torch.ones(grid, dtype=torch.bool, device=device)
    return probs, lock, lock_probs, attn, expanded
162
+
163
+
164
@torch.no_grad()
def decode_one_config(
    model,
    tokenizer,
    init,
    lock,
    lock_probs,
    attn,
    args,
    update: str,
    final_from: str,
    temp: float,
    model_t_mode: str,
    support_power: float,
    semantic_power: float,
    anchor_mode: str,
):
    """Run ``args.steps`` refinement steps for one decoding configuration.

    At each step the model predicts an endpoint distribution from the current
    state. The endpoint (optionally mixed with an anchor for ``dual_line*``
    updates) is blended with the uniform distribution via ``dirichlet_mean``,
    and the state is replaced according to ``update``. Locked positions are
    reset to ``lock_probs`` after every step and in the returned tensor.

    Args:
        update: One of ``"mean"``, ``"resample"``, ``"dual_line_mean"``,
            ``"dual_line_resample"`` or ``"ema_mean"``.
        final_from: ``"endpoint"`` returns the last model endpoint,
            ``"blend"`` averages it with the state, anything else returns
            the state.
        temp: Divisor applied to the logits before the softmax.

    Raises:
        ValueError: if ``update`` is not one of the supported rules.
    """
    locked = lock.unsqueeze(-1)
    probs = init.clone()
    last_endpoint = probs
    device = probs.device
    total_steps = args.steps

    for step in range(total_steps):
        model_t = model_time_for_step(model_t_mode, step, total_steps, probs.size(0), device, dtype=torch.float32)
        model_input = state_for_model(model, probs, args.eps)
        scaled_logits = model(model_input, model_t, attn).float() / temp
        endpoint = F.softmax(scaled_logits, dim=-1)
        last_endpoint = endpoint

        support_t = schedule_power(step, total_steps, support_power)
        semantic_t = schedule_power(step, total_steps, semantic_power)

        forward_endpoint = endpoint
        if update.startswith("dual_line"):
            anchor = current_anchor(probs, anchor_mode, args.eps)
            mixed = (1.0 - semantic_t) * anchor + semantic_t * endpoint
            forward_endpoint = mixed / mixed.sum(dim=-1, keepdim=True).clamp_min(args.eps)

        mean = dirichlet_mean(forward_endpoint, support_t, args.eps)

        if update in ("mean", "dual_line_mean"):
            new_probs = mean
        elif update in ("resample", "dual_line_resample"):
            new_probs = dirichlet_resample(mean, support_t, args.concentration_min, args.concentration_max, args.eps)
        elif update == "ema_mean":
            # The step size grows to 1 on the last step, so the final state
            # equals the final mean.
            gamma = 1.0 / max(total_steps - step, 1)
            averaged = (1.0 - gamma) * probs + gamma * mean
            new_probs = averaged / averaged.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        else:
            raise ValueError(update)

        probs = torch.where(locked, lock_probs, new_probs)

    if final_from == "endpoint":
        out = last_endpoint
    elif final_from == "blend":
        out = 0.5 * probs + 0.5 * last_endpoint
    else:
        out = probs
    out = torch.where(locked, lock_probs, out)
    return out / out.sum(dim=-1, keepdim=True).clamp_min(args.eps)
221
+
222
+
223
def main() -> None:
    """Sweep decoding configurations and write one JSON line per configuration.

    Parses the command line, loads the tokenizer and checkpoint, builds one
    shared initial state, then decodes every combination of update rule,
    final source, temperature, model-time mode, support/semantic power and
    anchor mode. Each row is printed and appended to ``--output``; the row
    with the highest mean quality is printed at the end.
    """
    import itertools  # local import: only needed to expand the sweep grid

    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--max_len", type=int, default=256)
    ap.add_argument("--steps", type=int, default=256)
    ap.add_argument("--restarts", type=int, default=4)
    ap.add_argument("--prompts", nargs="+", default=[""])
    ap.add_argument("--noise_init", default="dirichlet")
    ap.add_argument("--target_prob", type=float, default=0.99)
    ap.add_argument("--noise_sigma", type=float, default=-1.0)
    ap.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    ap.add_argument("--concentration_min", type=float, default=1.0)
    ap.add_argument("--concentration_max", type=float, default=1024.0)
    ap.add_argument("--updates", nargs="+", default=["mean", "ema_mean", "resample"])
    ap.add_argument("--finals", nargs="+", default=["state", "endpoint", "blend"])
    ap.add_argument("--temps", nargs="+", type=float, default=[1.0, 1.2, 1.35])
    ap.add_argument("--model_t_modes", nargs="+", default=["flow", "const05"])
    ap.add_argument("--support_powers", nargs="+", type=float, default=[1.0])
    ap.add_argument("--semantic_powers", nargs="+", type=float, default=[1.0])
    ap.add_argument("--anchor_modes", nargs="+", default=["onehot"])
    ap.add_argument("--lock_bos", action="store_true")
    ap.add_argument("--lock_final_eos", action="store_true")
    ap.add_argument("--eps", type=float, default=1e-8)
    ap.add_argument("--seed", type=int, default=1234)
    args = ap.parse_args()
    # With no restarts the batch is empty and the mean-quality division below
    # would fail only after the checkpoint has been loaded; reject it up front.
    if args.restarts < 1:
        ap.error("--restarts must be >= 1")

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    ckpt = torch.load(args.checkpoint, map_location=device)
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()
    init, lock, lock_probs, attn, expanded = build_initial(
        tokenizer,
        args.prompts,
        args.restarts,
        args.max_len,
        args.eps,
        args.noise_init,
        args.target_prob,
        args.noise_sigma,
        args.dirichlet_init_concentration,
        args.lock_bos,
        args.lock_final_eos,
        device,
    )

    # Cartesian product in the same order as nested loops would give:
    # the first factor varies slowest, the last fastest.
    configs = list(
        itertools.product(
            args.updates,
            args.finals,
            args.temps,
            args.model_t_modes,
            args.support_powers,
            args.semantic_powers,
            args.anchor_modes,
        )
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    with out_path.open("w") as f:
        for update, final_from, temp, model_t_mode, support_power, semantic_power, anchor_mode in configs:
            probs = decode_one_config(
                model,
                tokenizer,
                init,
                lock,
                lock_probs,
                attn,
                args,
                update,
                final_from,
                temp,
                model_t_mode,
                support_power,
                semantic_power,
                anchor_mode,
            )
            ids = probs.argmax(dim=-1).detach().cpu().tolist()
            texts = [decode_text(tokenizer, row) for row in ids]
            mets = [text_metrics(t) for t in texts]
            mean_q = sum(m["quality"] for m in mets) / len(mets)
            best_i = max(range(len(texts)), key=lambda i: mets[i]["quality"])
            row = {
                "update": update,
                "final_from": final_from,
                "temp": temp,
                "model_t_mode": model_t_mode,
                "support_power": support_power,
                "semantic_power": semantic_power,
                "anchor_mode": anchor_mode,
                "mean_quality": mean_q,
                "best_prompt": expanded[best_i],
                "best_metrics": mets[best_i],
                "best_text": texts[best_i],
            }
            rows.append(row)
            print(
                "\n====",
                update,
                final_from,
                temp,
                model_t_mode,
                "support_p",
                support_power,
                "semantic_p",
                semantic_power,
                "anchor",
                anchor_mode,
                "mean_q",
                round(mean_q, 4),
                flush=True,
            )
            print(texts[best_i][:1600], flush=True)
            # Flush per row so partial sweeps leave a usable output file.
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
            f.flush()
    best = max(rows, key=lambda r: r["mean_quality"])
    print("\nBEST", json.dumps({k: best[k] for k in best if k != "best_text"}, ensure_ascii=False, indent=2), flush=True)
    print(best["best_text"], flush=True)
344
+
345
+
346
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
LTA_openwebtext_dualt/scripts/infer_softkl_decode_probe.py.bak_lognsr_gumbel_20260519 ADDED
@@ -0,0 +1,960 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import sys
7
+ from dataclasses import asdict, is_dataclass
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+
14
+ REPO_ROOT = Path(__file__).resolve().parents[1]
15
+ if str(REPO_ROOT) not in sys.path:
16
+ sys.path.insert(0, str(REPO_ROOT))
17
+
18
+ from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
19
+ from flowtext_lab.genppl import filter_generated_texts, summarize_token_diversity
20
+ from flowtext_lab.model import EndpointPredictor
21
+ from flowtext_lab.tokenization import BpeTextTokenizer
22
+
23
+
24
def extend_pos_embed(sd: dict[str, torch.Tensor], max_len: int, mode: str) -> dict[str, torch.Tensor]:
    """Return a copy of ``sd`` whose ``"pos_embed"`` entry has length ``max_len``.

    The length is read from dimension 1 of the ``"pos_embed"`` tensor. A
    missing key or a matching length leaves the entry untouched; a longer
    table is truncated. A shorter table is extended by tiling
    (``mode="repeat"``) or by linear interpolation with aligned corners
    (``mode="interpolate"``).

    Raises:
        ValueError: if the table must be extended and ``mode`` is unknown.
    """
    result = dict(sd)
    if "pos_embed" not in result:
        return result

    table = result["pos_embed"]
    current_len = int(table.size(1))
    if current_len == max_len:
        return result

    if current_len > max_len:
        resized = table[:, :max_len]
    elif mode == "repeat":
        copies = math.ceil(max_len / current_len)
        resized = table.repeat(1, copies, 1)[:, :max_len]
    elif mode == "interpolate":
        # F.interpolate resamples the trailing dimension, so move the
        # position axis last, resample, then move it back.
        channels_first = table.transpose(1, 2)
        stretched = F.interpolate(channels_first, size=max_len, mode="linear", align_corners=True)
        resized = stretched.transpose(1, 2)
    else:
        raise ValueError(f"unknown pos_extend={mode}")

    result["pos_embed"] = resized.contiguous()
    return result
46
+
47
+
48
def build_model(ckpt: dict, tokenizer: BpeTextTokenizer, max_len: int, device: torch.device, pos_extend: str) -> EndpointPredictor:
    """Rebuild an ``EndpointPredictor`` from a checkpoint and load its weights.

    Hyperparameters come from ``ckpt["args"]`` with fallbacks for missing
    keys. When ``output_bias`` is not recorded it is inferred from the
    presence of an output bias tensor in the state dict. The positional table
    is resized to ``max_len`` with ``extend_pos_embed`` before a strict load.

    Returns:
        The model on ``device``, in eval mode.
    """
    hparams = ckpt.get("args", {})
    weights = ckpt["model"]

    if "output_bias" in hparams:
        has_output_bias = bool(hparams["output_bias"])
    else:
        has_output_bias = "output_layer.linear.bias" in weights or "out_proj.bias" in weights

    # "state_format" takes precedence over the "input_format" key.
    input_format = str(hparams.get("state_format", hparams.get("input_format", "prob")))

    model = EndpointPredictor(
        vocab_size=tokenizer.vocab_size,
        max_len=max_len,
        d_model=int(hparams.get("d_model", 768)),
        n_heads=int(hparams.get("n_heads", 12)),
        n_layers=int(hparams.get("n_layers", 12)),
        dim_ff=int(hparams.get("dim_ff", 3072)),
        dropout=0.0,
        model_type=str(hparams.get("model_type", "ddit")),
        cond_dim=int(hparams.get("cond_dim", 128)),
        input_format=input_format,
        output_bias=has_output_bias,
        norm_type=str(hparams.get("norm_type", "layernorm")),
        elf_num_time_tokens=int(hparams.get("elf_num_time_tokens", 4)),
        elf_num_model_mode_tokens=int(hparams.get("elf_num_model_mode_tokens", 0)),
        qk_norm=bool(hparams.get("qk_norm", True)),
        output_init_std=hparams.get("output_init_std", None),
        ddit_mlp_type=str(hparams.get("ddit_mlp_type", "gelu")),
    )
    model = model.to(device)

    resized_weights = extend_pos_embed(weights, max_len=max_len, mode=pos_extend)
    model.load_state_dict(resized_weights, strict=True)
    model.eval()
    return model
78
+
79
+
80
def total_concentration(support_t: float, c_min: float, c_max: float) -> float:
    """Interpolate log-linearly between ``c_min`` and ``c_max``.

    ``c_min`` is floored at ``1e-8`` and ``c_max`` is never allowed below that
    floored value, so the logarithms are always defined.

    Returns:
        ``exp(log(lo) + support_t * (log(hi) - log(lo)))``.
    """
    # The floored lower bound is also the floor for the upper bound; using the
    # raw c_min there let non-positive bounds reach math.log and raise.
    lo = max(c_min, 1e-8)
    hi = max(c_max, lo)
    log_lo = math.log(lo)
    return math.exp(log_lo + support_t * (math.log(hi) - log_lo))
82
+
83
+
84
def dirichlet_path_mean(endpoint: torch.Tensor, support_t: float, eps: float) -> torch.Tensor:
    """Mix ``endpoint`` with the uniform distribution at weight ``support_t``.

    The result is floored at ``eps`` and divided by its sum over the last
    dimension.
    """
    vocab_size = float(endpoint.size(-1))
    mixed = (1.0 - support_t) / vocab_size + support_t * endpoint
    floored = mixed.clamp_min(eps)
    total = floored.sum(dim=-1, keepdim=True).clamp_min(eps)
    return floored / total
89
+
90
+
91
def dirichlet_resample(mean: torch.Tensor, support_t: float, c_min: float, c_max: float, eps: float) -> torch.Tensor:
    """Sample around ``mean`` by normalizing independent gamma draws.

    The gamma shapes are ``mean * total_concentration(support_t, c_min,
    c_max)``; shapes and draws are both floored at ``eps`` before the draws
    are divided by their sum over the last dimension.
    """
    scale = total_concentration(support_t, c_min, c_max)
    gamma_shapes = (mean * scale).clamp_min(eps)
    gamma_draws = torch._standard_gamma(gamma_shapes).clamp_min(eps)
    return gamma_draws / gamma_draws.sum(dim=-1, keepdim=True).clamp_min(eps)
96
+
97
+
98
def current_anchor(probs: torch.Tensor, mode: str, eps: float) -> torch.Tensor:
    """Derive the anchor distribution from the current state.

    ``"onehot"`` returns the one-hot of the argmax in the dtype of ``probs``.
    ``"sqrt_state"`` renormalizes the square root of the floored state. Any
    other mode renormalizes the floored state itself.
    """
    if mode == "onehot":
        winners = probs.argmax(dim=-1)
        return F.one_hot(winners, probs.size(-1)).to(dtype=probs.dtype)

    floored = probs.clamp_min(eps)
    unnormalized = floored.sqrt() if mode == "sqrt_state" else floored
    return unnormalized / unnormalized.sum(dim=-1, keepdim=True).clamp_min(eps)
106
+
107
+
108
def log_geodesic_mix(p: torch.Tensor, q: torch.Tensor, gamma: float, eps: float) -> torch.Tensor:
    """Interpolate ``p`` and ``q`` in log space and renormalize with softmax.

    Both inputs are floored at ``eps`` before the logarithm; ``gamma`` is the
    weight on ``q``.
    """
    log_p = p.clamp_min(eps).log()
    log_q = q.clamp_min(eps).log()
    blended_logits = (1.0 - gamma) * log_p + gamma * log_q
    return torch.softmax(blended_logits, dim=-1)
111
+
112
+
113
def sqrt_geodesic_mix(p: torch.Tensor, q: torch.Tensor, gamma: float, eps: float) -> torch.Tensor:
    """Interpolate linearly between square roots, then square and renormalize.

    Inputs are floored at ``eps`` before the square root, and the squared
    result is floored again before normalization over the last dimension.
    """
    root_p = p.clamp_min(eps).sqrt()
    root_q = q.clamp_min(eps).sqrt()
    blended_root = (1.0 - gamma) * root_p + gamma * root_q
    squared = blended_root.square().clamp_min(eps)
    return squared / squared.sum(dim=-1, keepdim=True).clamp_min(eps)
117
+
118
+
119
def fisher_rao_mix(p: torch.Tensor, q: torch.Tensor, gamma: float, eps: float) -> torch.Tensor:
    """Interpolate between the square roots of ``p`` and ``q`` using sine weights.

    The angle is the arccosine of the dot product of the two square-root
    vectors. The interpolated root is squared, floored at ``eps`` and
    renormalized over the last dimension.
    """
    root_p = p.clamp_min(eps).sqrt()
    root_q = q.clamp_min(eps).sqrt()
    # Keep the cosine strictly inside (-1, 1) so acos stays finite and the
    # sine in the denominator does not collapse to zero.
    cosine = (root_p * root_q).sum(dim=-1, keepdim=True).clamp(-1.0 + 1e-6, 1.0 - 1e-6)
    angle = torch.acos(cosine)
    sine = torch.sin(angle).clamp_min(1e-6)
    weight_p = torch.sin((1.0 - gamma) * angle) / sine
    weight_q = torch.sin(gamma * angle) / sine
    squared = (weight_p * root_p + weight_q * root_q).square().clamp_min(eps)
    return squared / squared.sum(dim=-1, keepdim=True).clamp_min(eps)
130
+
131
+
132
def simplex_mix(p: torch.Tensor, q: torch.Tensor, gamma: float, eps: float, geometry: str) -> torch.Tensor:
    """Interpolate between ``p`` and ``q`` under the named geometry.

    ``"linear"`` is a convex combination, floored at ``eps`` and renormalized.
    ``"log"``, ``"sqrt"`` and ``"fisher"`` delegate to the matching helper.

    Raises:
        ValueError: for any other ``geometry``.
    """
    if geometry == "linear":
        blend = ((1.0 - gamma) * p + gamma * q).clamp_min(eps)
        return blend / blend.sum(dim=-1, keepdim=True).clamp_min(eps)
    elif geometry == "log":
        return log_geodesic_mix(p, q, gamma, eps)
    elif geometry == "sqrt":
        return sqrt_geodesic_mix(p, q, gamma, eps)
    elif geometry == "fisher":
        return fisher_rao_mix(p, q, gamma, eps)
    raise ValueError(geometry)
144
+
145
+
146
def temperature(step: int, steps: int, early: float, late: float, temp_end: float, power: float) -> float:
    """Anneal from ``early`` to ``late`` over the first ``temp_end`` fraction of steps.

    Progress is ``step / max(steps, 1)``. Once progress reaches ``temp_end``
    the result is ``late``; before that it is ``late + (early - late) *
    remaining ** power`` where ``remaining`` falls from 1 to 0.
    """
    fraction_done = step / max(steps, 1)
    if fraction_done >= temp_end:
        return late
    remaining = 1.0 - fraction_done / max(temp_end, 1e-8)
    return late + (early - late) * (remaining ** power)
152
+
153
+
154
def make_time_grid(
    steps: int,
    *,
    schedule: str,
    logit_mean: float,
    logit_std: float,
    power: float,
    seed: int,
    device: torch.device,
) -> torch.Tensor:
    """Return ``steps + 1`` sorted time points running from 0.0 to 1.0.

    ``"uniform"`` is an even grid. ``"logit_normal"``, ``"power_low"`` and
    ``"power_high"`` draw ``steps - 1`` interior points from a CPU generator
    seeded with ``seed``, sort them, and pin the endpoints to 0 and 1.

    Raises:
        ValueError: if ``steps`` is not positive or ``schedule`` is unknown.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")
    if schedule == "uniform":
        return torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=torch.float32)
    if schedule not in {"logit_normal", "power_low", "power_high"}:
        raise ValueError(f"unknown time schedule: {schedule}")

    # A single step has no interior points to sample.
    if steps == 1:
        return torch.tensor([0.0, 1.0], device=device, dtype=torch.float32)

    rng = torch.Generator(device="cpu")
    rng.manual_seed(int(seed))
    n_interior = steps - 1
    if schedule == "logit_normal":
        gaussian = torch.randn((n_interior,), generator=rng, dtype=torch.float32)
        interior = torch.sigmoid(gaussian * float(logit_std) + float(logit_mean))
    else:
        uniform = torch.rand((n_interior,), generator=rng, dtype=torch.float32)
        exponent = max(float(power), 1e-8)
        if schedule == "power_low":
            interior = uniform.pow(exponent)
        else:
            interior = 1.0 - (1.0 - uniform).pow(exponent)

    interior = interior.sort().values.to(device)
    first = torch.zeros((1,), device=device, dtype=torch.float32)
    last = torch.ones((1,), device=device, dtype=torch.float32)
    return torch.cat([first, interior, last])
202
+
203
+
204
def clamp_first_position(probs: torch.Tensor, first_ids: torch.Tensor | None) -> torch.Tensor:
    """Force position 0 of every row to a one-hot on ``first_ids``.

    Returns ``probs`` itself when ``first_ids`` is ``None``; otherwise a
    modified clone, leaving the input untouched.
    """
    if first_ids is None:
        return probs
    clamped = probs.clone()
    first_slot = clamped[:, 0, :]  # view into the clone, edited in place
    first_slot.zero_()
    first_slot.scatter_(1, first_ids.unsqueeze(1), 1.0)
    return clamped
211
+
212
+
213
def final_decode_ids(
    probs: torch.Tensor,
    *,
    mode: str,
    temp: float,
    top_k: int,
    top_p: float,
    eps: float,
) -> torch.Tensor:
    """Turn per-position distributions into token ids.

    Args:
        probs: Tensor whose last dimension indexes the vocabulary.
        mode: ``"argmax"`` for greedy decoding, ``"sample"`` for sampling.
        temp: Sampling temperature, floored at ``eps``.
        top_k: Keep only the ``top_k`` highest logits when
            ``0 < top_k < vocab``; ties with the k-th value are kept.
        top_p: Nucleus threshold, active when ``0 < top_p < 1``.
        eps: Floor applied to ``probs`` before the logarithm.

    Returns:
        Integer ids with the shape of ``probs`` minus its last dimension.

    Raises:
        ValueError: if ``mode`` is neither ``"argmax"`` nor ``"sample"``.
    """
    if mode == "argmax":
        return probs.argmax(dim=-1)
    if mode != "sample":
        raise ValueError(mode)

    logits = probs.clamp_min(eps).log() / max(float(temp), eps)

    if 0 < top_k < logits.size(-1):
        kth = logits.topk(top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, -torch.inf)

    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        exceeded = cumulative > float(top_p)
        # Shift the mask one slot to the right so the token that first pushes
        # the cumulative mass past top_p stays in the nucleus. Without the
        # shift the kept mass could fall below top_p. The top token is always
        # kept because slot 0 of the shifted mask is False.
        remove = torch.zeros_like(exceeded)
        remove[..., 1:] = exceeded[..., :-1]
        sorted_logits = sorted_logits.masked_fill(remove, -torch.inf)
        # Scatter the filtered logits back to vocabulary order.
        logits = torch.full_like(logits, -torch.inf).scatter(-1, sorted_idx, sorted_logits)

    sample_probs = F.softmax(logits, dim=-1)
    flat = sample_probs.reshape(-1, sample_probs.size(-1))
    return torch.multinomial(flat, num_samples=1).view(probs.shape[:-1])
241
+
242
+
243
def onehot_from_ids(ids: torch.Tensor, vocab_size: int, dtype: torch.dtype) -> torch.Tensor:
    """One-hot encode ``ids`` (negative ids are clamped to 0) in ``dtype``."""
    non_negative = ids.clamp_min(0)
    encoded = F.one_hot(non_negative, vocab_size)
    return encoded.to(dtype=dtype)
245
+
246
+
247
def apply_committed_state(
    probs: torch.Tensor,
    committed: torch.Tensor,
    committed_ids: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Overwrite committed positions with a one-hot on their committed id.

    Returns ``probs`` unchanged when nothing is committed. Otherwise the
    merged tensor is floored at ``eps`` and renormalized over the last
    dimension, so committed rows are one-hot up to the ``eps`` floor.
    """
    if not bool(committed.any()):
        return probs
    # Negative ids are clamped to 0 before encoding.
    hard = F.one_hot(committed_ids.clamp_min(0), probs.size(-1)).to(dtype=probs.dtype)
    merged = torch.where(committed.unsqueeze(-1), hard, probs).clamp_min(eps)
    return merged / merged.sum(dim=-1, keepdim=True).clamp_min(eps)
259
+
260
+
261
def commit_target_ratio(
    progress: float,
    *,
    start: float,
    min_ratio: float,
    max_ratio: float,
    power: float,
) -> float:
    """Target fraction of committed positions at the given progress.

    Returns 0.0 before ``start`` (clamped to [0, 0.999999]). Afterwards the
    ratio rises from ``min_ratio`` to ``max_ratio`` along a power curve of the
    rescaled progress, and is clipped to [0, 1].
    """
    begin = min(max(float(start), 0.0), 0.999999)
    if progress < begin:
        return 0.0
    clipped_progress = min(max(float(progress), begin), 1.0)
    rel = (clipped_progress - begin) / max(1.0 - begin, 1e-8)
    exponent = max(float(power), 1e-8)
    low, high = float(min_ratio), float(max_ratio)
    ratio = low + (high - low) * (rel ** exponent)
    return min(max(ratio, 0.0), 1.0)
275
+
276
+
277
def select_commit_positions(
    endpoint: torch.Tensor,
    committed: torch.Tensor,
    *,
    mode: str,
    step: int,
    steps: int,
    progress: float,
    threshold: float,
    margin_threshold: float,
    start: float,
    min_ratio: float,
    max_ratio: float,
    power: float,
    freq_max_frac: float,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
    """Choose which not-yet-committed positions to commit at this step.

    Confidence is the top-1 probability of ``endpoint`` and margin is the gap
    between the top-1 and top-2 probabilities.

    Modes:
        ``"off"``: never commit.
        ``"threshold"``: commit every candidate whose confidence and margin
            meet their thresholds.
        ``"ratio"``: per row, commit the most confident threshold-passing
            candidates until the target ratio from ``commit_target_ratio``
            is reached.
        ``"ratio_margin_freq"``: like ``"ratio"`` but ranked by margin, and
            no token id may be chosen more than
            ``floor(length * freq_max_frac)`` times (at least 1).

    For every mode except ``"off"``, the final step (``step == steps - 1``)
    commits all remaining candidates, and nothing is committed while
    ``progress < start``.

    Returns:
        ``(new_commit, pred_ids, stats)`` where ``new_commit`` is a boolean
        mask of newly committed positions, ``pred_ids`` is the argmax id per
        position, and ``stats`` holds ``target_ratio``, ``mean_conf_new`` and
        ``mean_margin_new``.

    Raises:
        ValueError: for an unknown ``mode``, once the early exits are passed.
    """
    top2 = endpoint.topk(k=2, dim=-1)
    conf = top2.values[..., 0]
    margin = top2.values[..., 0] - top2.values[..., 1]
    pred_ids = top2.indices[..., 0]
    candidates = ~committed
    final_step = step == steps - 1
    if mode == "off":
        new_commit = torch.zeros_like(committed)
        return new_commit, pred_ids, {"target_ratio": 0.0, "mean_conf_new": 0.0, "mean_margin_new": 0.0}
    if final_step:
        # Last step: commit everything still open, regardless of thresholds.
        new_commit = candidates
        return new_commit, pred_ids, {
            "target_ratio": 1.0,
            "mean_conf_new": float(conf[new_commit].mean().detach().cpu()) if bool(new_commit.any()) else 0.0,
            "mean_margin_new": float(margin[new_commit].mean().detach().cpu()) if bool(new_commit.any()) else 0.0,
        }
    if progress < start:
        new_commit = torch.zeros_like(committed)
        return new_commit, pred_ids, {"target_ratio": 0.0, "mean_conf_new": 0.0, "mean_margin_new": 0.0}

    if mode == "threshold":
        new_commit = candidates & (conf >= float(threshold)) & (margin >= float(margin_threshold))
    elif mode in {"ratio", "ratio_margin_freq"}:
        target_ratio = commit_target_ratio(
            progress,
            start=start,
            min_ratio=min_ratio,
            max_ratio=max_ratio,
            power=power,
        )
        bs, length = committed.shape
        new_commit = torch.zeros_like(committed)
        # "ratio" ranks candidates by confidence, "ratio_margin_freq" by margin.
        score = conf if mode == "ratio" else margin
        for b in range(bs):
            already = int(committed[b].sum().detach().cpu())
            target_count = min(length, int(math.ceil(float(length) * target_ratio)))
            # Number of extra positions needed to reach the target count.
            need = max(0, target_count - already)
            if need <= 0:
                continue
            cand_idx = torch.nonzero(candidates[b], as_tuple=False).flatten()
            if cand_idx.numel() == 0:
                continue
            # Thresholds still gate eligibility, so fewer than `need`
            # positions may end up committed.
            valid = cand_idx[(conf[b, cand_idx] >= float(threshold)) & (margin[b, cand_idx] >= float(margin_threshold))]
            if valid.numel() == 0:
                continue
            order = torch.argsort(score[b, valid], descending=True)
            valid = valid[order]
            if mode == "ratio":
                chosen = valid[:need]
            else:
                chosen_list: list[int] = []
                max_per_token = max(1, int(math.floor(float(length) * float(freq_max_frac))))
                counts: dict[int, int] = {}
                if bool(committed[b].any()):
                    # NOTE(review): counts are seeded from the current argmax at
                    # committed positions, not from the ids actually committed
                    # earlier; these may differ -- confirm this is intended.
                    ids = pred_ids[b, committed[b]].detach().cpu().tolist()
                    for tok in ids:
                        counts[int(tok)] = counts.get(int(tok), 0) + 1
                for idx_t in valid.detach().cpu().tolist():
                    tok = int(pred_ids[b, idx_t].detach().cpu())
                    # Skip tokens that already hit the per-token cap.
                    if counts.get(tok, 0) >= max_per_token:
                        continue
                    chosen_list.append(int(idx_t))
                    counts[tok] = counts.get(tok, 0) + 1
                    if len(chosen_list) >= need:
                        break
                if chosen_list:
                    chosen = torch.tensor(chosen_list, dtype=torch.long, device=committed.device)
                else:
                    # Empty index tensor with the dtype and device of `valid`.
                    chosen = valid[:0]
            if chosen.numel() > 0:
                new_commit[b, chosen] = True
    else:
        raise ValueError(mode)

    return new_commit, pred_ids, {
        "target_ratio": (
            0.0
            if mode == "threshold"
            else commit_target_ratio(progress, start=start, min_ratio=min_ratio, max_ratio=max_ratio, power=power)
        ),
        "mean_conf_new": float(conf[new_commit].mean().detach().cpu()) if bool(new_commit.any()) else 0.0,
        "mean_margin_new": float(margin[new_commit].mean().detach().cpu()) if bool(new_commit.any()) else 0.0,
    }
376
+
377
+
378
def soften_endpoint_with_prior(
    endpoint: torch.Tensor,
    t: float,
    *,
    mode: str,
    power: float,
    min_conf: float,
    max_conf: float,
    eps: float,
) -> tuple[torch.Tensor, float]:
    """Optionally mix ``endpoint`` with the uniform prior.

    With ``mode="none"`` the endpoint is returned untouched with confidence
    1.0. With ``mode="uniform"`` the confidence is ``min_conf + (max_conf -
    min_conf) * t ** power`` clipped to [0, 1], and the result is
    ``alpha * endpoint + (1 - alpha) / vocab``, floored and renormalized.

    Returns:
        ``(softened, alpha)``.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "none":
        return endpoint, 1.0
    if mode != "uniform":
        raise ValueError(mode)

    raw_alpha = float(min_conf) + (float(max_conf) - float(min_conf)) * (float(t) ** float(power))
    alpha = max(0.0, min(1.0, raw_alpha))
    uniform_mass = 1.0 / float(endpoint.shape[-1])
    mixed = (alpha * endpoint + (1.0 - alpha) * uniform_mass).clamp_min(eps)
    mixed = mixed / mixed.sum(dim=-1, keepdim=True).clamp_min(eps)
    return mixed, alpha
399
+
400
+
401
def linear_soft_target_confidence(
    t: float,
    *,
    power: float,
    min_conf: float,
    max_conf: float,
) -> float:
    """Return ``min_conf + (max_conf - min_conf) * t ** power`` clipped to [0, 1]."""
    low = float(min_conf)
    span = float(max_conf) - float(min_conf)
    raw = low + span * (float(t) ** float(power))
    return max(0.0, min(1.0, raw))
410
+
411
+
412
def decode_linear_soft_target_endpoint(
    posterior: torch.Tensor,
    t: float,
    *,
    mode: str,
    power: float,
    min_conf: float,
    max_conf: float,
    debias_start: float,
    eps: float,
) -> tuple[torch.Tensor, float]:
    """Interpret logits trained with q_t=(1-alpha)Uniform+alpha*onehot.

    ``"posterior"`` returns the prediction as is. ``"debias"`` inverts the
    linear target, i.e. subtracts the uniform share and divides by alpha.
    ``"late_debias"`` does the same, but only once ``t`` reaches
    ``debias_start``, because inverting while alpha is small amplifies noise.

    Returns:
        ``(endpoint, alpha)``, where ``alpha`` comes from
        ``linear_soft_target_confidence``.

    Raises:
        ValueError: for any other ``mode``.
    """
    alpha = linear_soft_target_confidence(t, power=power, min_conf=min_conf, max_conf=max_conf)

    too_early = mode == "late_debias" and float(t) < float(debias_start)
    if mode == "posterior" or too_early:
        return posterior, alpha
    if mode not in {"debias", "late_debias"}:
        raise ValueError(mode)

    uniform_mass = 1.0 / float(posterior.shape[-1])
    # Floor the divisor so that alpha == 0 cannot cause a division by zero.
    divisor = max(alpha, eps)
    recovered = ((posterior - (1.0 - alpha) * uniform_mass) / divisor).clamp_min(eps)
    recovered = recovered / recovered.sum(dim=-1, keepdim=True).clamp_min(eps)
    return recovered, alpha
439
+
440
+
441
+ @torch.inference_mode()
442
+ def decode(
443
+ model: EndpointPredictor,
444
+ tokenizer: BpeTextTokenizer,
445
+ *,
446
+ max_len: int,
447
+ n_samples: int,
448
+ batch_size: int,
449
+ steps: int,
450
+ seed: int,
451
+ device: torch.device,
452
+ decode_rule: str,
453
+ support_power: float,
454
+ semantic_power: float,
455
+ early_temp: float,
456
+ late_temp: float,
457
+ temp_end: float,
458
+ temp_power: float,
459
+ hybrid_switch: float,
460
+ tail_temp: float,
461
+ c_min: float,
462
+ c_max: float,
463
+ anchor_mode: str,
464
+ model_t_mode: str,
465
+ time_schedule: str,
466
+ time_logit_mean: float,
467
+ time_logit_std: float,
468
+ time_power: float,
469
+ input_noise_scale: float,
470
+ input_noise_until: float,
471
+ input_noise_dirichlet_concentration: float,
472
+ endpoint_softening: str,
473
+ endpoint_soft_power: float,
474
+ endpoint_soft_min_conf: float,
475
+ endpoint_soft_max_conf: float,
476
+ soft_target_decode_mode: str,
477
+ soft_target_power: float,
478
+ soft_target_min_conf: float,
479
+ soft_target_max_conf: float,
480
+ soft_target_debias_start: float,
481
+ final_from: str,
482
+ final_decode: str,
483
+ final_sample_temp: float,
484
+ final_top_k: int,
485
+ final_top_p: float,
486
+ commit_mode: str,
487
+ commit_conf_threshold: float,
488
+ commit_margin_threshold: float,
489
+ commit_start: float,
490
+ commit_min_ratio: float,
491
+ commit_max_ratio: float,
492
+ commit_power: float,
493
+ commit_freq_max_frac: float,
494
+ eps: float,
495
+ fixed_first_token_id: int | None,
496
+ fixed_first_initial_argmax: bool,
497
+ ) -> tuple[list[list[int]], list[str], list[dict[str, object]]]:
498
+ torch.manual_seed(seed)
499
+ time_grid = make_time_grid(
500
+ steps,
501
+ schedule=time_schedule,
502
+ logit_mean=time_logit_mean,
503
+ logit_std=time_logit_std,
504
+ power=time_power,
505
+ seed=seed,
506
+ device=device,
507
+ )
508
+ all_ids: list[list[int]] = []
509
+ all_texts: list[str] = []
510
+ traces: list[dict[str, object]] = []
511
+ remaining = n_samples
512
+ while remaining > 0:
513
+ bs = min(batch_size, remaining)
514
+ probs = sample_noise_simplex(
515
+ (bs, max_len),
516
+ tokenizer.vocab_size,
517
+ device,
518
+ eps,
519
+ noise_mode="dirichlet",
520
+ target_prob=1.0,
521
+ noise_sigma=-1.0,
522
+ dirichlet_concentration=1.0,
523
+ )
524
+ fixed_first_ids: torch.Tensor | None = None
525
+ if fixed_first_initial_argmax:
526
+ fixed_first_ids = probs[:, 0, :].argmax(dim=-1)
527
+ elif fixed_first_token_id is not None:
528
+ fixed_first_ids = torch.full((bs,), int(fixed_first_token_id), dtype=torch.long, device=device)
529
+ probs = clamp_first_position(probs, fixed_first_ids)
530
+ committed = torch.zeros((bs, max_len), dtype=torch.bool, device=device)
531
+ committed_ids = torch.zeros((bs, max_len), dtype=torch.long, device=device)
532
+ if fixed_first_ids is not None:
533
+ committed[:, 0] = True
534
+ committed_ids[:, 0] = fixed_first_ids
535
+ attn = torch.ones((bs, max_len), dtype=torch.bool, device=device)
536
+ last_endpoint = probs
537
+ for step in range(steps):
538
+ progress = float(time_grid[step].item())
539
+ next_progress = float(time_grid[step + 1].item())
540
+ dt = max(next_progress - progress, 0.0)
541
+ if model_t_mode in {"pre", "flow"}:
542
+ t = torch.full((bs,), float(progress), dtype=torch.float32, device=device)
543
+ elif model_t_mode == "post":
544
+ t = torch.full((bs,), float(next_progress), dtype=torch.float32, device=device)
545
+ else:
546
+ t = model_time_for_step(model_t_mode, step, steps, bs, device, dtype=torch.float32)
547
+ temp = temperature(step, steps, early_temp, late_temp, temp_end, temp_power)
548
+ if tail_temp > 0 and progress >= hybrid_switch:
549
+ temp = tail_temp
550
+ model_probs = probs
551
+ if input_noise_scale > 0.0 and progress < input_noise_until:
552
+ fresh_noise = sample_noise_simplex(
553
+ (bs, max_len),
554
+ tokenizer.vocab_size,
555
+ device,
556
+ eps,
557
+ noise_mode="dirichlet",
558
+ target_prob=1.0,
559
+ noise_sigma=-1.0,
560
+ dirichlet_concentration=input_noise_dirichlet_concentration,
561
+ )
562
+ noisy = progress * probs + (1.0 - progress) * float(input_noise_scale) * fresh_noise
563
+ model_probs = noisy.clamp_min(eps)
564
+ model_probs = model_probs / model_probs.sum(dim=-1, keepdim=True).clamp_min(eps)
565
+ logits = model(state_for_model(model, model_probs, eps), t, attn).float()
566
+ raw_endpoint = F.softmax(logits / temp, dim=-1)
567
+ if soft_target_decode_mode == "off":
568
+ endpoint, endpoint_alpha = soften_endpoint_with_prior(
569
+ raw_endpoint,
570
+ next_progress,
571
+ mode=endpoint_softening,
572
+ power=endpoint_soft_power,
573
+ min_conf=endpoint_soft_min_conf,
574
+ max_conf=endpoint_soft_max_conf,
575
+ eps=eps,
576
+ )
577
+ else:
578
+ endpoint, endpoint_alpha = decode_linear_soft_target_endpoint(
579
+ raw_endpoint,
580
+ next_progress,
581
+ mode=soft_target_decode_mode,
582
+ power=soft_target_power,
583
+ min_conf=soft_target_min_conf,
584
+ max_conf=soft_target_max_conf,
585
+ debias_start=soft_target_debias_start,
586
+ eps=eps,
587
+ )
588
+ new_commit, pred_ids, commit_stats = select_commit_positions(
589
+ endpoint,
590
+ committed,
591
+ mode=commit_mode,
592
+ step=step,
593
+ steps=steps,
594
+ progress=next_progress,
595
+ threshold=commit_conf_threshold,
596
+ margin_threshold=commit_margin_threshold,
597
+ start=commit_start,
598
+ min_ratio=commit_min_ratio,
599
+ max_ratio=commit_max_ratio,
600
+ power=commit_power,
601
+ freq_max_frac=commit_freq_max_frac,
602
+ )
603
+ if bool(new_commit.any()):
604
+ committed_ids = torch.where(new_commit, pred_ids, committed_ids)
605
+ committed = committed | new_commit
606
+ last_endpoint = endpoint
607
+ support_t = next_progress ** support_power
608
+ if decode_rule == "dirichlet_resample":
609
+ probs = dirichlet_resample(dirichlet_path_mean(endpoint, support_t, eps), support_t, c_min, c_max, eps)
610
+ elif decode_rule == "dual_line_resample":
611
+ semantic_t = next_progress ** semantic_power
612
+ anchor = current_anchor(probs, anchor_mode, eps)
613
+ forward_endpoint = (1.0 - semantic_t) * anchor + semantic_t * endpoint
614
+ forward_endpoint = forward_endpoint.clamp_min(eps)
615
+ forward_endpoint = forward_endpoint / forward_endpoint.sum(dim=-1, keepdim=True).clamp_min(eps)
616
+ probs = dirichlet_resample(dirichlet_path_mean(forward_endpoint, support_t, eps), support_t, c_min, c_max, eps)
617
+ elif decode_rule == "dual_replace_resample":
618
+ semantic_t = next_progress ** semantic_power
619
+ anchor = current_anchor(probs, anchor_mode, eps)
620
+ replace = torch.rand((bs, max_len, 1), device=device) < semantic_t
621
+ forward_endpoint = torch.where(replace, endpoint, anchor)
622
+ forward_endpoint = forward_endpoint.clamp_min(eps)
623
+ forward_endpoint = forward_endpoint / forward_endpoint.sum(dim=-1, keepdim=True).clamp_min(eps)
624
+ probs = dirichlet_resample(dirichlet_path_mean(forward_endpoint, support_t, eps), support_t, c_min, c_max, eps)
625
+ elif decode_rule in {"log_dual_resample", "sqrt_dual_resample", "fisher_dual_resample"}:
626
+ geometry = decode_rule.split("_", 1)[0]
627
+ semantic_t = next_progress ** semantic_power
628
+ anchor = current_anchor(probs, anchor_mode, eps)
629
+ forward_endpoint = simplex_mix(anchor, endpoint, semantic_t, eps, geometry)
630
+ probs = dirichlet_resample(dirichlet_path_mean(forward_endpoint, support_t, eps), support_t, c_min, c_max, eps)
631
+ elif decode_rule == "flowmap":
632
+ gamma = min(dt / max(1.0 - progress, eps), 1.0)
633
+ probs = probs + gamma * (endpoint - probs)
634
+ probs = probs.clamp_min(eps)
635
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
636
+ elif decode_rule in {"log_geodesic", "sqrt_geodesic", "fisher_geodesic"}:
637
+ geometry = decode_rule.split("_", 1)[0]
638
+ gamma = min(dt / max(1.0 - progress, eps), 1.0)
639
+ probs = simplex_mix(probs, endpoint, gamma, eps, geometry)
640
+ elif decode_rule in {"hybrid_log_flowmap", "hybrid_log_dirres", "hybrid_log_logflow"}:
641
+ if progress < hybrid_switch:
642
+ local = min(1.0, next_progress / max(hybrid_switch, 1e-8))
643
+ semantic_t = local ** semantic_power
644
+ anchor = current_anchor(probs, anchor_mode, eps)
645
+ forward_endpoint = simplex_mix(anchor, endpoint, semantic_t, eps, "log")
646
+ probs = dirichlet_resample(dirichlet_path_mean(forward_endpoint, support_t, eps), support_t, c_min, c_max, eps)
647
+ elif decode_rule == "hybrid_log_flowmap":
648
+ gamma = min(dt / max(1.0 - progress, eps), 1.0)
649
+ probs = simplex_mix(probs, endpoint, gamma, eps, "linear")
650
+ elif decode_rule == "hybrid_log_logflow":
651
+ gamma = min(dt / max(1.0 - progress, eps), 1.0)
652
+ probs = simplex_mix(probs, endpoint, gamma, eps, "log")
653
+ else:
654
+ probs = dirichlet_resample(dirichlet_path_mean(endpoint, support_t, eps), support_t, c_min, c_max, eps)
655
+ else:
656
+ raise ValueError(decode_rule)
657
+ probs = clamp_first_position(probs, fixed_first_ids)
658
+ probs = apply_committed_state(probs, committed, committed_ids, eps)
659
+ if step in {0, 1, 3, 7, 15, 31, 63, steps - 1}:
660
+ ids0 = probs.argmax(dim=-1)[0].detach().cpu().tolist()
661
+ raw_maxprob = raw_endpoint[0].amax(dim=-1).mean().detach().item()
662
+ soft_maxprob = endpoint[0].amax(dim=-1).mean().detach().item()
663
+ commit_frac = committed[0].float().mean().detach().item()
664
+ new_commit_frac = new_commit[0].float().mean().detach().item()
665
+ traces.append({
666
+ "step": step + 1,
667
+ "progress": progress,
668
+ "next_progress": next_progress,
669
+ "dt": dt,
670
+ "temperature": temp,
671
+ "endpoint_alpha": endpoint_alpha,
672
+ "raw_endpoint_mean_maxprob": raw_maxprob,
673
+ "effective_endpoint_mean_maxprob": soft_maxprob,
674
+ "commit_mode": commit_mode,
675
+ "commit_frac": commit_frac,
676
+ "new_commit_frac": new_commit_frac,
677
+ **commit_stats,
678
+ "sample0_text": tokenizer.decode(ids0, stop_at_eos=False, skip_special_tokens=False)[:480],
679
+ })
680
+ if final_from == "state":
681
+ final = probs
682
+ elif final_from == "endpoint":
683
+ final = last_endpoint
684
+ elif final_from == "blend":
685
+ final = 0.5 * probs + 0.5 * last_endpoint
686
+ else:
687
+ raise ValueError(final_from)
688
+ final = clamp_first_position(final, fixed_first_ids)
689
+ ids_tensor = final_decode_ids(
690
+ final,
691
+ mode=final_decode,
692
+ temp=final_sample_temp,
693
+ top_k=final_top_k,
694
+ top_p=final_top_p,
695
+ eps=eps,
696
+ )
697
+ ids = ids_tensor.detach().cpu().tolist()
698
+ all_ids.extend(ids)
699
+ all_texts.extend(tokenizer.decode(row, stop_at_eos=False, skip_special_tokens=False) for row in ids)
700
+ remaining -= bs
701
+ print(f"[decode] max_len={max_len} generated={n_samples-remaining}/{n_samples}", flush=True)
702
+ return all_ids, all_texts, traces
703
+
704
+
705
def score_with_gpt2(texts: list[str], scorer_path: str, batch_size: int, max_length: int, device: torch.device) -> dict[str, float]:
    """Score generated texts with a locally stored causal LM.

    Args:
        texts: Generated strings to score.
        scorer_path: Local path of the scorer tokenizer/model. Both are loaded
            with ``local_files_only=True``, so nothing is downloaded.
        batch_size: Number of texts per scorer forward pass.
        max_length: Truncation length passed to the scorer tokenizer.
        device: Device the scorer and its inputs are moved to.

    Returns:
        A dict with:
          * ``gen_nll``: summed next-token NLL divided by the number of scored
            positions (0.0 when nothing was scored).
          * ``gen_ppl``: ``exp(min(20.0, gen_nll))``; the exponent is capped at 20.
          * ``gen_tokens``: number of scored positions, i.e. target positions
            whose attention-mask entry is set.
    """
    scorer_tok = AutoTokenizer.from_pretrained(scorer_path, local_files_only=True)
    if scorer_tok.pad_token is None:
        # Batched padding needs a pad token; padded positions are masked out below.
        scorer_tok.pad_token = scorer_tok.eos_token
    scorer = AutoModelForCausalLM.from_pretrained(scorer_path, local_files_only=True).to(device).eval()
    total_nll = 0.0
    total_tokens = 0
    # Scoring is evaluation-only. Without no_grad every forward pass would keep
    # an autograd graph alive for the whole batch, which only wastes memory.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            enc = scorer_tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(device)
            input_ids = enc["input_ids"]
            attn = enc["attention_mask"]
            if input_ids.size(1) < 2:
                # A next-token loss needs at least one (input, target) pair.
                continue
            # Swap the last two axes so the class axis sits where
            # F.cross_entropy expects it for sequence inputs.
            logits = scorer(input_ids=input_ids, attention_mask=attn).logits.transpose(-1, -2)
            # Prediction at position i is scored against the token at i + 1.
            nll = F.cross_entropy(logits[..., :-1].float(), input_ids[..., 1:], reduction="none")
            mask = attn[..., 1:].bool()
            total_nll += float(nll[mask].sum().detach().cpu())
            total_tokens += int(mask.sum().detach().cpu())
    del scorer
    if device.type == "cuda":
        torch.cuda.empty_cache()
    mean = total_nll / max(total_tokens, 1)
    return {"gen_ppl": math.exp(min(20.0, mean)), "gen_nll": mean, "gen_tokens": total_tokens}
729
+
730
+
731
def main() -> None:
    """CLI entry point: decode samples per context length and write reports.

    For every context length listed in ``--max_lens`` this builds the model,
    runs ``decode``, and writes ``context{L}_samples.txt`` and
    ``context{L}_trace.json`` into ``--out_dir``. One record per context length
    (run settings, token-diversity statistics, a text preview and, with
    ``--score``, scorer metrics) is accumulated into ``summary.json``.

    Raises:
        ValueError: if ``--fixed_first_token_text`` encodes to no tokens, or if
            ``--max_lens`` contains no usable context length.
    """
    import sys  # local import: only needed for the stderr warning below

    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--max_lens", default="128,1024")
    ap.add_argument("--n_samples", type=int, default=16)
    ap.add_argument("--batch_size", type=int, default=2)
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument(
        "--decode_rule",
        choices=[
            "dual_line_resample",
            "dual_replace_resample",
            "dirichlet_resample",
            "flowmap",
            "log_dual_resample",
            "sqrt_dual_resample",
            "fisher_dual_resample",
            "log_geodesic",
            "sqrt_geodesic",
            "fisher_geodesic",
            "hybrid_log_flowmap",
            "hybrid_log_dirres",
            "hybrid_log_logflow",
        ],
        default="dual_line_resample",
    )
    ap.add_argument("--pos_extend", choices=["repeat", "interpolate"], default="repeat")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.5)
    ap.add_argument("--early_temp", type=float, default=2.8)
    ap.add_argument("--late_temp", type=float, default=1.45)
    ap.add_argument("--temp_end", type=float, default=0.55)
    ap.add_argument("--temp_power", type=float, default=1.5)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    ap.add_argument("--tail_temp", type=float, default=-1.0)
    ap.add_argument("--c_min", type=float, default=1.0)
    ap.add_argument("--c_max", type=float, default=1024.0)
    ap.add_argument("--anchor_mode", choices=["state", "onehot", "sqrt_state"], default="state")
    ap.add_argument(
        "--model_t_mode",
        choices=["pre", "post", "flow", "linear", "const0", "const05", "const1", "random"],
        default="flow",
    )
    ap.add_argument("--time_schedule", choices=["uniform", "logit_normal", "power_low", "power_high"], default="uniform")
    ap.add_argument("--time_logit_mean", type=float, default=-1.5)
    ap.add_argument("--time_logit_std", type=float, default=0.8)
    ap.add_argument("--time_power", type=float, default=2.0)
    ap.add_argument("--input_noise_scale", type=float, default=0.0)
    ap.add_argument("--input_noise_until", type=float, default=1.0)
    ap.add_argument("--input_noise_dirichlet_concentration", type=float, default=1.0)
    ap.add_argument("--endpoint_softening", choices=["none", "uniform"], default="none")
    ap.add_argument("--endpoint_soft_power", type=float, default=2.0)
    ap.add_argument("--endpoint_soft_min_conf", type=float, default=0.0)
    ap.add_argument("--endpoint_soft_max_conf", type=float, default=1.0)
    ap.add_argument("--soft_target_decode_mode", choices=["off", "posterior", "debias", "late_debias"], default="off")
    ap.add_argument("--soft_target_power", type=float, default=1.0)
    ap.add_argument("--soft_target_min_conf", type=float, default=0.0)
    ap.add_argument("--soft_target_max_conf", type=float, default=1.0)
    ap.add_argument("--soft_target_debias_start", type=float, default=0.7)
    ap.add_argument("--final_from", choices=["state", "endpoint", "blend"], default="blend")
    ap.add_argument("--final_decode", choices=["argmax", "sample"], default="argmax")
    ap.add_argument("--final_sample_temp", type=float, default=1.0)
    ap.add_argument("--final_top_k", type=int, default=0)
    ap.add_argument("--final_top_p", type=float, default=1.0)
    ap.add_argument("--commit_mode", choices=["off", "threshold", "ratio", "ratio_margin_freq"], default="off")
    ap.add_argument("--commit_conf_threshold", type=float, default=0.0)
    ap.add_argument("--commit_margin_threshold", type=float, default=0.0)
    ap.add_argument("--commit_start", type=float, default=0.0)
    ap.add_argument("--commit_min_ratio", type=float, default=0.0)
    ap.add_argument("--commit_max_ratio", type=float, default=1.0)
    ap.add_argument("--commit_power", type=float, default=2.0)
    ap.add_argument("--commit_freq_max_frac", type=float, default=0.08)
    ap.add_argument("--fixed_first_token_id", type=int, default=-1)
    ap.add_argument("--fixed_first_token_text", default="")
    ap.add_argument("--fixed_first_initial_argmax", action="store_true")
    ap.add_argument("--scorer", default="/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard")
    ap.add_argument("--score", action="store_true")
    ap.add_argument("--use_ema", action="store_true", help="Use ema_model from checkpoint if present.")
    ap.add_argument("--seed", type=int, default=20260514)
    args = ap.parse_args()

    # Validate --max_lens before any expensive loading. Blank entries (for
    # example a trailing comma) are skipped instead of crashing in int("").
    max_lens = [int(part) for part in args.max_lens.split(",") if part.strip()]
    if not max_lens:
        raise ValueError(f"--max_lens contains no context lengths: {args.max_lens!r}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = BpeTextTokenizer.from_file(args.tokenizer_path)
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True)
    ema_used = bool(args.use_ema and "ema_model" in ckpt)
    if ema_used:
        # Shallow-copy so the loaded mapping is not mutated, then expose the
        # EMA weights under the "model" key.
        ckpt = dict(ckpt)
        ckpt["model"] = ckpt["ema_model"]
    elif args.use_ema:
        # Best-effort fallback is kept, but it is no longer silent.
        print(
            f"[warn] --use_ema was set but {args.checkpoint!r} has no 'ema_model'; using 'model' weights instead.",
            file=sys.stderr,
            flush=True,
        )
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # --fixed_first_token_text takes precedence over --fixed_first_token_id;
    # only the first encoded token of the text is used.
    fixed_first_token_id: int | None = None
    if args.fixed_first_token_text:
        encoded = tok.encode(args.fixed_first_token_text, add_eos=False, add_special_tokens=False)
        if not encoded:
            raise ValueError(f"fixed_first_token_text encoded to no tokens: {args.fixed_first_token_text!r}")
        fixed_first_token_id = int(encoded[0])
    elif args.fixed_first_token_id >= 0:
        fixed_first_token_id = int(args.fixed_first_token_id)

    summary = []
    for max_len in max_lens:
        model = build_model(ckpt, tok, max_len, device, args.pos_extend)
        ids, texts, traces = decode(
            model,
            tok,
            max_len=max_len,
            n_samples=args.n_samples,
            batch_size=args.batch_size,
            steps=args.steps,
            # Offset the seed by the context length so each length gets its own stream.
            seed=args.seed + max_len,
            device=device,
            decode_rule=args.decode_rule,
            support_power=args.support_power,
            semantic_power=args.semantic_power,
            early_temp=args.early_temp,
            late_temp=args.late_temp,
            temp_end=args.temp_end,
            temp_power=args.temp_power,
            hybrid_switch=args.hybrid_switch,
            tail_temp=args.tail_temp,
            c_min=args.c_min,
            c_max=args.c_max,
            anchor_mode=args.anchor_mode,
            model_t_mode=args.model_t_mode,
            time_schedule=args.time_schedule,
            time_logit_mean=args.time_logit_mean,
            time_logit_std=args.time_logit_std,
            time_power=args.time_power,
            input_noise_scale=args.input_noise_scale,
            input_noise_until=args.input_noise_until,
            input_noise_dirichlet_concentration=args.input_noise_dirichlet_concentration,
            endpoint_softening=args.endpoint_softening,
            endpoint_soft_power=args.endpoint_soft_power,
            endpoint_soft_min_conf=args.endpoint_soft_min_conf,
            endpoint_soft_max_conf=args.endpoint_soft_max_conf,
            soft_target_decode_mode=args.soft_target_decode_mode,
            soft_target_power=args.soft_target_power,
            soft_target_min_conf=args.soft_target_min_conf,
            soft_target_max_conf=args.soft_target_max_conf,
            soft_target_debias_start=args.soft_target_debias_start,
            final_from=args.final_from,
            final_decode=args.final_decode,
            final_sample_temp=args.final_sample_temp,
            final_top_k=args.final_top_k,
            final_top_p=args.final_top_p,
            commit_mode=args.commit_mode,
            commit_conf_threshold=args.commit_conf_threshold,
            commit_margin_threshold=args.commit_margin_threshold,
            commit_start=args.commit_start,
            commit_min_ratio=args.commit_min_ratio,
            commit_max_ratio=args.commit_max_ratio,
            commit_power=args.commit_power,
            commit_freq_max_frac=args.commit_freq_max_frac,
            eps=1e-8,
            fixed_first_token_id=fixed_first_token_id,
            fixed_first_initial_argmax=args.fixed_first_initial_argmax,
        )
        # The helpers may return either a bare value or a tuple / dataclass;
        # normalise both shapes before use.
        filt_result = filter_generated_texts(texts, min_chars=0, normalize_whitespace=True, drop_empty=False)
        filt = filt_result[0] if isinstance(filt_result, tuple) else filt_result
        diversity_result = summarize_token_diversity(ids)
        diversity = asdict(diversity_result) if is_dataclass(diversity_result) else dict(diversity_result)
        rec = {
            "checkpoint": args.checkpoint,
            "ckpt_step": int(ckpt.get("step", -1)),
            "max_len": max_len,
            "decode_rule": args.decode_rule,
            "support_power": args.support_power,
            "semantic_power": args.semantic_power,
            "steps": args.steps,
            "c_min": args.c_min,
            "c_max": args.c_max,
            "anchor_mode": args.anchor_mode,
            "model_t_mode": args.model_t_mode,
            "time_schedule": args.time_schedule,
            "time_logit_mean": args.time_logit_mean,
            "time_logit_std": args.time_logit_std,
            "time_power": args.time_power,
            "input_noise_scale": args.input_noise_scale,
            "input_noise_until": args.input_noise_until,
            "input_noise_dirichlet_concentration": args.input_noise_dirichlet_concentration,
            "endpoint_softening": args.endpoint_softening,
            "endpoint_soft_power": args.endpoint_soft_power,
            "endpoint_soft_min_conf": args.endpoint_soft_min_conf,
            "endpoint_soft_max_conf": args.endpoint_soft_max_conf,
            "soft_target_decode_mode": args.soft_target_decode_mode,
            "soft_target_power": args.soft_target_power,
            "soft_target_min_conf": args.soft_target_min_conf,
            "soft_target_max_conf": args.soft_target_max_conf,
            "soft_target_debias_start": args.soft_target_debias_start,
            "final_from": args.final_from,
            "final_decode": args.final_decode,
            "final_sample_temp": args.final_sample_temp,
            "final_top_k": args.final_top_k,
            "final_top_p": args.final_top_p,
            "commit_mode": args.commit_mode,
            "commit_conf_threshold": args.commit_conf_threshold,
            "commit_margin_threshold": args.commit_margin_threshold,
            "commit_start": args.commit_start,
            "commit_min_ratio": args.commit_min_ratio,
            "commit_max_ratio": args.commit_max_ratio,
            "commit_power": args.commit_power,
            "commit_freq_max_frac": args.commit_freq_max_frac,
            "early_temp": args.early_temp,
            "late_temp": args.late_temp,
            "temp_end": args.temp_end,
            "temp_power": args.temp_power,
            "pos_extend": args.pos_extend,
            "fixed_first_token_id": fixed_first_token_id,
            "fixed_first_token_text": args.fixed_first_token_text,
            "fixed_first_initial_argmax": bool(args.fixed_first_initial_argmax),
            # True only when EMA weights were actually substituted above.
            "use_ema": ema_used,
            "n_samples": len(texts),
            **diversity,
            "texts_preview": filt[:4],
        }
        (out_dir / f"context{max_len}_samples.txt").write_text("\n\n---\n\n".join(filt), encoding="utf-8")
        (out_dir / f"context{max_len}_trace.json").write_text(json.dumps(traces, ensure_ascii=False, indent=2), encoding="utf-8")
        # Release the decode model before the scorer is loaded.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()
        if args.score:
            rec.update(score_with_gpt2(filt, args.scorer, batch_size=2, max_length=min(max_len, 1024), device=device))
        summary.append(rec)
        # Rewritten after every context length so finished results survive a
        # failure on a later (usually longer) context.
        (out_dir / "summary.json").write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(summary, ensure_ascii=False, indent=2), flush=True)


if __name__ == "__main__":
    main()
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_dualtline_cmax16_8gpu_duo_small_1m.sh ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch LTA training on LM1B (ddit model, d_model 768 x 12 layers) through
# torch.distributed.run.
#
# Every setting written as "${NAME:-default}" can be overridden from the
# environment. Output of the training command is appended to LOG_FILE and
# checkpoints go to SAVE_DIR.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"

# --- Run identity, data and output locations -------------------------------
RUN_NAME="${RUN_NAME:-lta_lm1b_duo_aligned_dirichlet_true_dualtline_cmax16_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0}"
DATA_PATH="${DATA_PATH:-data/lm1b_train_parquet}"
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
TEXT_COLUMN="${TEXT_COLUMN:-}"
OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-all}"
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
LOG_FILE="${LOG_FILE:-logs/${RUN_NAME}.log}"

# --- Distributed topology ---------------------------------------------------
NNODES="${NNODES:-1}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
NODE_RANK="${NODE_RANK:-0}"
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
MASTER_PORT="${MASTER_PORT:-29621}"

# --- Training schedule ------------------------------------------------------
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-64}"
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
WARMUP_STEPS="${WARMUP_STEPS:-2500}"
MAX_LEN="${MAX_LEN:-128}"
WRAP_MODE="${WRAP_MODE:-stream}"
WRAP_RECORD_BUFFER_SIZE="${WRAP_RECORD_BUFFER_SIZE:-200}"
NUM_WORKERS="${NUM_WORKERS:-0}"
LOG_EVERY="${LOG_EVERY:-100}"
SAVE_EVERY="${SAVE_EVERY:-20000}"
LATEST_EVERY="${LATEST_EVERY:-1000}"
EVAL_EVERY="${EVAL_EVERY:-0}"
RESUME_PATH="${RESUME_PATH:-}"
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
FORCE_DISABLE_TORCH_COMPILE="${FORCE_DISABLE_TORCH_COMPILE:-1}"

# FORCE_DISABLE_TORCH_COMPILE defaults to 1, so torch.compile stays off unless
# it is cleared AND ENABLE_TORCH_COMPILE=1 is set.
if [[ "${FORCE_DISABLE_TORCH_COMPILE}" == "1" ]]; then
  ENABLE_TORCH_COMPILE=0
fi
if [[ "${DATA_PATH}" == *"lm1b_train_parquet"* && "${NUM_WORKERS}" != "0" ]]; then
  echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
  NUM_WORKERS=0
fi

# Optional argument groups: each array stays empty unless its feature is on.
COMPILE_ARGS=()
if [[ "${ENABLE_TORCH_COMPILE}" == "1" ]]; then
  COMPILE_ARGS+=(--torch_compile --compile_mode reduce-overhead)
fi
RESUME_ARGS=()
if [[ -n "${RESUME_PATH}" ]]; then
  RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
fi
TEXT_COLUMN_ARGS=()
if [[ -n "${TEXT_COLUMN}" ]]; then
  TEXT_COLUMN_ARGS+=(--text_column "${TEXT_COLUMN}")
fi

# Do not start a fresh run on top of an existing one by accident.
if [[ -f "${SAVE_DIR}/args.json" && -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
  echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
  echo "Use a new RUN_NAME/SAVE_DIR, set RESUME_PATH to resume, or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
  exit 2
fi

# LOG_FILE may be overridden to a path outside logs/, so create its parent
# directory too; otherwise tee cannot open the log.
mkdir -p logs runs "${SAVE_DIR}" "$(dirname -- "${LOG_FILE}")"

# The optional arrays are expanded as ${arr[@]+"${arr[@]}"}: under `set -u`,
# bash older than 4.4 treats a plain "${arr[@]}" on an empty array as an
# unbound variable and aborts, and all three arrays are empty by default.
python -m torch.distributed.run \
  --nnodes="${NNODES}" \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank="${NODE_RANK}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py \
  --data_path "${DATA_PATH}" \
  ${TEXT_COLUMN_ARGS[@]+"${TEXT_COLUMN_ARGS[@]}"} \
  --openwebtext_split "${OPENWEBTEXT_SPLIT}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --save_dir "${SAVE_DIR}" \
  --wrap \
  --wrap_mode "${WRAP_MODE}" \
  --wrap_record_buffer_size "${WRAP_RECORD_BUFFER_SIZE}" \
  --max_len "${MAX_LEN}" \
  --batch_size "${PER_GPU_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE}" \
  --total_steps "${TOTAL_STEPS}" \
  --log_every "${LOG_EVERY}" \
  --eval_every "${EVAL_EVERY}" \
  --save_every "${SAVE_EVERY}" \
  --latest_every "${LATEST_EVERY}" \
  --lr 3e-4 \
  --weight_decay 0 \
  --adam_beta1 0.9 \
  --adam_beta2 0.999 \
  --adam_eps 1e-8 \
  --warmup_steps "${WARMUP_STEPS}" \
  --lr_schedule constant_warmup \
  --grad_clip 1.0 \
  --seed 123 \
  --d_model 768 \
  --cond_dim 128 \
  --n_layers 12 \
  --n_heads 12 \
  --dim_ff 3072 \
  --dropout 0.1 \
  --model_type ddit \
  --state_format prob \
  --bridge dirichlet \
  --target_loss hard_ce \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --dual_t \
  --corrupt_t_mode independent \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio 0.1 \
  --max_mask_ratio 1.0 \
  --wrong_token_replace_prob 1.0 \
  --wrong_token_schedule linear_t \
  --wrong_token_exp_k 1.0 \
  --dirichlet_concentration_min 1.0 \
  --dirichlet_concentration_max 16.0 \
  --dirichlet_endpoint_mode dual_t_line \
  --dirichlet_semantic_t_mode independent \
  --dirichlet_semantic_t_value 0.0 \
  --eps 1e-8 \
  --infer_steps 128 \
  --decode_damping 1.0 \
  --max_gamma 1.0 \
  --decode_solver flowmap \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma -1 \
  ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} \
  ${COMPILE_ARGS[@]+"${COMPILE_ARGS[@]}"} \
  --bf16 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_gamma2_8gpu.sh ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch LTA training on cached OpenWebText chunks (ddit model, max_len 1024
# by default) through torch.distributed.run.
#
# Every setting written as "${NAME:-default}" can be overridden from the
# environment. Output of the training command is appended to LOG_FILE and
# checkpoints go to SAVE_DIR.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"

# The default RUN_NAME embeds the launch timestamp, so every launch gets its
# own SAVE_DIR and log file.
# NOTE(review): with NNODES > 1 each node evaluates `date` on its own, so the
# names can differ between nodes; presumably RUN_NAME should then be set
# explicitly - confirm against how train.py uses --save_dir.
RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_maskfloor_logitpow2gamma2_c1024_ddit768x12_muon_ema_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
LOG_DIR="${LOG_DIR:-logs/gamma2_8gpu}"
LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
# LOG_FILE may be overridden to a path outside LOG_DIR, so create its parent
# directory too; otherwise tee cannot open the log.
mkdir -p "${LOG_DIR}" "${SAVE_DIR}" "$(dirname -- "${LOG_FILE}")"

OWT_CACHE="${OWT_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"

# Performance toggles: "1" enables a feature. The TF32 and bucket-view
# switches always pass an explicit on/off flag, so PERF_ARGS is never empty.
PERF_ARGS=()
if [[ "${ALLOW_TF32:-1}" == "1" ]]; then
  PERF_ARGS+=(--allow_tf32)
else
  PERF_ARGS+=(--no-allow_tf32)
fi
if [[ "${ACTIVATION_CHECKPOINTING:-1}" == "1" ]]; then
  PERF_ARGS+=(
    --activation_checkpointing
    --activation_checkpoint_interval "${ACTIVATION_CHECKPOINT_INTERVAL:-2}"
    --activation_checkpoint_scope "${ACTIVATION_CHECKPOINT_SCOPE:-block}"
  )
fi
if [[ "${DDP_STATIC_GRAPH:-0}" == "1" ]]; then
  PERF_ARGS+=(--ddp_static_graph)
fi
if [[ "${DDP_GRADIENT_AS_BUCKET_VIEW:-1}" == "1" ]]; then
  PERF_ARGS+=(--ddp_gradient_as_bucket_view)
else
  PERF_ARGS+=(--no-ddp_gradient_as_bucket_view)
fi
if [[ "${BLOCKING_DATA_TRANSFER:-0}" == "1" ]]; then
  PERF_ARGS+=(--blocking_data_transfer)
fi

python -m torch.distributed.run \
  --nnodes="${NNODES:-1}" \
  --nproc_per_node="${NPROC_PER_NODE:-8}" \
  --node_rank="${NODE_RANK:-0}" \
  --master_addr="${MASTER_ADDR:-127.0.0.1}" \
  --master_port="${MASTER_PORT:-31995}" \
  train.py \
  --data_path /e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext \
  --text_column text \
  --openwebtext_split train_minus_100k \
  --detokenizer auto \
  --tokenizer_path /e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json \
  --save_dir "${SAVE_DIR}" \
  --wrap \
  --wrap_mode stream \
  --wrap_record_buffer_size 200 \
  --owt_cached_chunks \
  --owt_chunk_cache_dir "${OWT_CACHE}" \
  --owt_chunk_cache_write_batch 4096 \
  --max_len "${MAX_LEN:-1024}" \
  --batch_size "${PER_GPU_BATCH_SIZE:-32}" \
  --num_workers "${NUM_WORKERS:-8}" \
  --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR:-4}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE:-512}" \
  --total_steps "${TOTAL_STEPS:-1000000}" \
  --warmup_steps "${WARMUP_STEPS:-2000}" \
  --log_every "${LOG_EVERY:-50}" \
  --eval_every "${EVAL_EVERY:-0}" \
  --save_every "${SAVE_EVERY:-50000}" \
  --latest_every "${LATEST_EVERY:-1000}" \
  --lr "${LEARNING_RATE:-0.002}" \
  --lr_schedule constant_warmup \
  --min_lr 0 \
  --weight_decay 0.0 \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --adam_eps 1e-8 \
  --optimizer muon \
  --muon_momentum 0.95 \
  --muon_ns_steps 5 \
  --muon_update_scale 1.0 \
  --ema_decay 0.9999 \
  --ema_start_step 0 \
  --grad_clip 1.0 \
  --adamw_param_groups nanogpt \
  --seed 123 \
  --d_model 768 \
  --cond_dim 128 \
  --n_layers 12 \
  --n_heads 12 \
  --dim_ff 3072 \
  --dropout 0.0 \
  --model_type ddit \
  --state_format prob \
  --bridge dirichlet \
  --target_loss hard_ce \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --dual_t \
  --corrupt_t_mode same \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio 0.1 \
  --max_mask_ratio 1.0 \
  --mask_ratio_floor_schedule one_minus_t \
  --wrong_token_replace_prob 1.0 \
  --wrong_token_schedule linear_t \
  --wrong_token_exp_k 1.0 \
  --dirichlet_concentration_min 1.0 \
  --dirichlet_concentration_max 1024 \
  --dirichlet_endpoint_mode categorical_dual_t \
  --dirichlet_semantic_t_mode same \
  --dirichlet_semantic_t_value 0.0 \
  --dirichlet_semantic_t_curve logit_power \
  --dirichlet_semantic_t_power 2.0 \
  --endpoint_sequence_random_prob_alpha 0.0 \
  --categorical_wrong_from_full_vocab \
  --simplex_bridge_sampler dirichlet \
  --eps 1e-8 \
  --infer_steps 1024 \
  --decode_damping 1.0 \
  --max_gamma 1.0 \
  --decode_solver flowmap \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma -1 \
  "${PERF_ARGS[@]}" \
  --bf16 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_adaln_adamw_wd0p1_rollin_p50_randk0_3_2node.sh ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+ set -o pipefail
5
+
6
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
7
+
8
+ # LTA's working torchrun on this cluster is /usr/local/bin/torchrun, which is
9
+ # bound to /usr/bin/python. Do not source the ConvNeXt my_env by default: that
10
+ # env currently lacks the HuggingFace tokenizers package needed by T5 tokenized
11
+ # data. Set ACTIVATE_ENV=/path/to/bin/activate only when you know it has the
12
+ # LTA dependencies installed.
13
+ if [[ -n "${ACTIVATE_ENV:-}" ]]; then
14
+ source "${ACTIVATE_ENV}"
15
+ fi
16
+
17
+ # MLP multi-node launcher variables. These can still be overridden manually.
18
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
19
+ NNODES="${NNODES:-${MLP_WORKER_NUM:-2}}"
20
+ NODE_RANK="${NODE_RANK:-${MLP_ROLE_INDEX:-0}}"
21
+ MASTER_ADDR="${MASTER_ADDR:-${MLP_WORKER_0_HOST:-127.0.0.1}}"
22
+ MASTER_PORT="${MASTER_PORT:-${MLP_WORKER_0_PORT:-32268}}"
23
+ NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi --list-gpus | wc -l)}"
24
+ TOTAL_GPUS=$(( NNODES * NPROC_PER_NODE ))
25
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-$(( GLOBAL_BATCH_SIZE / TOTAL_GPUS ))}"
26
+
27
+ if [[ "${TOTAL_GPUS}" -le 0 || "${PER_GPU_BATCH_SIZE}" -le 0 ]]; then
28
+ echo "Invalid batch/GPU config: GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} NNODES=${NNODES} NPROC_PER_NODE=${NPROC_PER_NODE}" >&2
29
+ exit 2
30
+ fi
31
+
32
# Default to the first NPROC_PER_NODE device indices unless the caller already
# pinned CUDA_VISIBLE_DEVICES. The value is assigned first and exported
# afterwards so a failing `seq` is not hidden behind the exit status of
# `export`.
if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then
  CUDA_VISIBLE_DEVICES="$(seq -s, 0 $((NPROC_PER_NODE - 1)))"
fi
export CUDA_VISIBLE_DEVICES

# Put the repository root (the current directory) first on PYTHONPATH.
PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export PYTHONPATH
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1
export NCCL_DEBUG="${NCCL_DEBUG:-INFO}"
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"

# Preflight: log the interpreter and the torch/tokenizers versions. A failed
# import makes this command exit non-zero before torchrun is started.
# NOTE(review): this checks whichever `python` is first on PATH, which may not
# be the interpreter torchrun itself is bound to -- confirm on the cluster.
python - <<'PY'
import sys
import torch
import tokenizers

print(f"[launch] python={sys.executable}")
print(f"[launch] torch={torch.__version__} tokenizers={tokenizers.__version__}")
PY
52
+
53
# Dataset and tokenizer locations (overridable from the environment).
export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"

# RUN_NAME must match across nodes. In MLP jobs, prefer the job id if present;
# otherwise MLP_WORKER_0_PORT is usually stable across workers for one job.
RUN_ID="${RUN_ID:-${MLP_TASK_ID:-${MLP_JOB_ID:-${MLP_RUN_ID:-${MLP_WORKER_0_PORT:-}}}}}"
if [[ -z "${RUN_NAME:-}" ]]; then
  if [[ -z "${RUN_ID}" ]]; then
    # Last resort: a local timestamp. Each node computes its own, so on a
    # multi-node job the derived RUN_NAME is not guaranteed to agree.
    RUN_ID="$(date +%Y%m%d_%H%M%S)"
    if [[ "${NNODES:-1}" -gt 1 ]]; then
      echo "[launch] WARNING: no RUN_NAME/RUN_ID/MLP_* id found; falling back to local timestamp ${RUN_ID}, which may differ between the ${NNODES} nodes. Set RUN_NAME or RUN_ID explicitly." >&2
    fi
  fi
  # Keep only characters that are safe inside a directory / file name.
  RUN_ID="$(printf "%s" "${RUN_ID}" | tr -c 'A-Za-z0-9_-' '_')"
  export RUN_NAME="lta_owt_t5_adaln_adamw_wd0p1_rollin_p50_randk0_3_uniformt_temp1_synct_gbs${GLOBAL_BATCH_SIZE}_${NNODES}node${NPROC_PER_NODE}gpu_1m_${RUN_ID}"
fi

# One log file per node; checkpoints go under runs/<RUN_NAME>.
export LOG_DIR="${LOG_DIR:-logs/elfaligned_t5tokenized_2node}"
export LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.node${NODE_RANK}.log}"
mkdir -p "${LOG_DIR}" "runs/${RUN_NAME}"

echo "[launch] run_name=${RUN_NAME}"
echo "[launch] node_rank=${NODE_RANK}/${NNODES} nproc_per_node=${NPROC_PER_NODE} master=${MASTER_ADDR}:${MASTER_PORT}"
echo "[launch] total_gpus=${TOTAL_GPUS} global_batch_size=${GLOBAL_BATCH_SIZE} per_gpu_batch_size=${PER_GPU_BATCH_SIZE}"
echo "[launch] data_path=${DATA_PATH}"
echo "[launch] tokenizer=${TOKENIZER_PATH}"
echo "[launch] log_file=${LOG_FILE}"
77
+
78
# Rendezvous / process-topology options consumed by torchrun itself.
torchrun_args=(
  --nnodes="${NNODES}"
  --nproc_per_node="${NPROC_PER_NODE}"
  --node_rank="${NODE_RANK}"
  --master_addr="${MASTER_ADDR}"
  --master_port="${MASTER_PORT}"
)

# Options forwarded to train.py, grouped by concern. The order is the order
# they are passed on the command line.
train_args=(
  # Data, tokenizer, output location, batch sizes, loader.
  --data_path "${DATA_PATH}"
  --tokenized_hf
  --tokenized_pad_token pad
  --tokenizer_path "${TOKENIZER_PATH}"
  --save_dir "runs/${RUN_NAME}"
  --max_len 1024
  --batch_size "${PER_GPU_BATCH_SIZE}"
  --global_batch_size "${GLOBAL_BATCH_SIZE}"
  --num_workers 8
  --dataloader_prefetch_factor 4

  # Schedule length, logging and checkpoint cadence.
  --epochs 0
  --total_steps 1000000
  --warmup_steps 1000
  --log_every 100
  --eval_every 0
  --save_every 10000
  --latest_every 1000

  # Optimizer, learning-rate schedule, EMA, clipping, seed.
  --optimizer adamw
  --lr 6e-4
  --lr_schedule cosine
  --min_lr 6e-5
  --weight_decay 0.1
  --output_weight_decay -1
  --adamw_param_groups nanogpt
  --adam_beta1 0.9
  --adam_beta2 0.999
  --adam_eps 1e-8
  --ema_decay 0.9999
  --ema_start_step 0
  --grad_clip 1.0
  --seed 42

  # Model architecture.
  --d_model 768
  --cond_dim 128
  --n_layers 12
  --n_heads 12
  --dim_ff 3072
  --dropout 0.0
  --no-output_bias
  --norm_type rmsnorm
  --model_type ddit
  --ddit_mlp_type swiglu

  # State representation, bridge and loss.
  --state_format prob
  --bridge dirichlet
  --target_loss hard_ce
  --loss_t_weight_mode none
  --loss_t_min_weight 0.0

  # Rollout-in-training settings.
  --rollout_train_prob 0.50
  --rollout_train_time_mode sampled_path
  --rollout_train_steps 3
  --rollout_train_steps_min 0
  --rollout_train_infer_steps 1
  --rollout_train_s_dist uniform
  --rollout_train_s_min_frac 0.0
  --rollout_train_s_max_frac 0.25
  --rollout_train_temp 1.0
  --rollout_train_max_gamma 1.0
  --rollout_train_corrupt_only
  --rollout_train_samplewise
  --rollout_train_selected_only
  --no-rollout_train_compute_always
  --rollout_train_sync_t

  # Time sampling for the model and for corruption.
  --target_prob 1.0
  --min_t 0.0
  --max_t 1.0
  --t_sampling_mode uniform
  --t_sampling_logit_mean -1.5
  --t_sampling_logit_std 0.8
  --t_sampling_eps 1e-4
  --dual_t
  --corrupt_t_mode same
  --corrupt_min_t 0.0
  --corrupt_max_t 1.0

  # Masking mixture and wrong-token replacement.
  --min_mask_ratio 1.0
  --max_mask_ratio 1.0
  --mask_mixture_original_prob 0.0
  --mask_mixture_lowk_prob 0.0
  --mask_mixture_lowcorrupt_prob 0.0
  --mask_mixture_block_prob 0.0
  --mask_mixture_all_prob 1.0
  --wrong_token_replace_prob 1.0
  --wrong_token_schedule linear_t
  --wrong_token_exp_k 1.0

  # Dirichlet bridge parameters.
  --dirichlet_concentration_min 1.0
  --dirichlet_concentration_max 1024
  --dirichlet_endpoint_mode categorical_dual_t
  --dirichlet_semantic_t_mode same
  --dirichlet_semantic_t_value 0.0
  --categorical_wrong_from_full_vocab
  --simplex_bridge_sampler dirichlet
  --eps 1e-8

  # Decoding / inference defaults stored with the run.
  --infer_steps 1024
  --decode_damping 1.0
  --max_gamma 1.0
  --decode_solver flowmap
  --noise_init logistic_normal
  --bridge_noise_init logistic_normal
  --noise_sigma -1

  # Performance switches.
  --allow_tf32
  --activation_checkpointing
  --activation_checkpoint_scope mlp
  --ddp_gradient_as_bucket_view
)

# stdout and stderr of the training job are shown live and appended to the
# per-node log file.
torchrun "${torchrun_args[@]}" train.py "${train_args[@]}" \
  2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/qwen_transformers_openai_server.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+ import threading
5
+ import time
6
+ import uuid
7
+ from typing import Any
8
+
9
+ import torch
10
+ import uvicorn
11
+ from fastapi import FastAPI, Header, HTTPException
12
+ from pydantic import BaseModel
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
+
16
class GenerateRequest(BaseModel):
    """Request body accepted by both the chat and the text completion routes.

    Either ``messages`` (chat style) or ``prompt`` (plain completion style)
    must be supplied; ``messages`` takes precedence when both are present.
    """

    # Model name sent by the client; it is only echoed back in the response.
    model: str | None = None
    # Chat turns handed to the tokenizer's chat template.
    messages: list[dict[str, Any]] | None = None
    # Raw prompt text, or a list of prompts of which only the first is used.
    prompt: str | list[str] | None = None
    # Upper bound on newly generated tokens (clamped to at least 1 later).
    max_tokens: int = 512
    # Sampling is enabled only when the temperature is greater than zero.
    temperature: float = 0.7
    top_p: float = 0.95
    # One stop string or several; the output is cut at the earliest match.
    stop: str | list[str] | None = None
    # Streaming is not implemented; a true value is rejected with HTTP 400.
    stream: bool = False
    # Convenience switch forwarded to the chat template as ``enable_thinking``.
    enable_thinking: bool | None = None
    # Extra keyword arguments forwarded verbatim to the chat template.
    chat_template_kwargs: dict[str, Any] | None = None
27
+
28
+
29
def parse_args() -> argparse.Namespace:
    """Parse the server's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``model_path``, ``served_model_name``, ``host``,
        ``port``, ``api_key``, ``max_memory_gib`` and ``torch_dtype``.
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--model-path", {"required": True}),
        ("--served-model-name", {"default": "qwen3.5-35b-a3b"}),
        ("--host", {"default": "127.0.0.1"}),
        ("--port", {"type": int, "default": 8000}),
        # The API key defaults to the API_KEY environment variable, if set.
        ("--api-key", {"default": os.environ.get("API_KEY", "")}),
        ("--max-memory-gib", {"type": int, "default": 88}),
        ("--torch-dtype", {"default": "bfloat16"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
39
+
40
+
41
# Module-level setup: parse the CLI, create the app, and load tokenizer and
# model once at import time so every request reuses them.
args = parse_args()
app = FastAPI()
# Held around model.generate so only one generation runs at a time.
lock = threading.Lock()

# Resolve --torch-dtype by name (e.g. "bfloat16", "float16", "float32").
# An unknown name aborts start-up rather than silently loading in float16.
dtype = getattr(torch, args.torch_dtype, None)
if not isinstance(dtype, torch.dtype):
    raise SystemExit(f"unsupported --torch-dtype: {args.torch_dtype!r}")

tokenizer = AutoTokenizer.from_pretrained(
    args.model_path,
    trust_remote_code=True,
    local_files_only=True,
)

# Per-device memory budget handed to the automatic device map.
# NOTE(review): the budget lists exactly GPU 0 and GPU 1 plus a fixed 128GiB
# CPU allowance -- confirm this matches the hosts the server is deployed on.
max_memory = {
    0: f"{args.max_memory_gib}GiB",
    1: f"{args.max_memory_gib}GiB",
    "cpu": "128GiB",
}
model = AutoModelForCausalLM.from_pretrained(
    args.model_path,
    trust_remote_code=True,
    local_files_only=True,
    torch_dtype=dtype,
    device_map="auto",
    max_memory=max_memory,
    low_cpu_mem_usage=True,
)
model.eval()
67
+
68
+
69
def require_auth(authorization: str | None) -> None:
    """Reject the request unless it carries the configured bearer token.

    Authentication is disabled when no API key was configured.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not equal ``Bearer <key>``.
    """
    import hmac  # local import: only this check needs it

    if not args.api_key:
        return
    expected = f"Bearer {args.api_key}".encode("utf-8")
    supplied = (authorization or "").encode("utf-8")
    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Bytes are compared because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
75
+
76
+
77
def make_prompt(req: GenerateRequest) -> str:
    """Build the text prompt for one request.

    Chat requests (``messages`` present) are rendered through the tokenizer's
    chat template with a generation prompt appended; otherwise the raw
    ``prompt`` is used.

    Args:
        req: Parsed request body.

    Returns:
        The prompt string to tokenize.

    Raises:
        HTTPException: 400 when neither ``messages`` nor a usable ``prompt``
            is supplied, including an empty prompt list.
    """
    if req.messages is not None:
        template_kwargs = dict(req.chat_template_kwargs or {})
        # The explicit flag wins over the same key inside chat_template_kwargs.
        if req.enable_thinking is not None:
            template_kwargs["enable_thinking"] = req.enable_thinking
        return tokenizer.apply_chat_template(
            req.messages,
            tokenize=False,
            add_generation_prompt=True,
            **template_kwargs,
        )
    prompt = req.prompt
    if isinstance(prompt, list):
        # An empty list used to raise IndexError (HTTP 500); it is a client
        # error, so answer with 400 like a missing prompt.
        if not prompt:
            raise HTTPException(status_code=400, detail="messages or prompt is required")
        # Batched prompts are not supported: only the first entry is used.
        return prompt[0]
    if isinstance(prompt, str):
        return prompt
    raise HTTPException(status_code=400, detail="messages or prompt is required")
93
+
94
+
95
+ def strip_stop(text: str, stop: str | list[str] | None) -> str:
96
+ stops = [stop] if isinstance(stop, str) else (stop or [])
97
+ cut = len(text)
98
+ for s in stops:
99
+ if not s:
100
+ continue
101
+ idx = text.find(s)
102
+ if idx >= 0:
103
+ cut = min(cut, idx)
104
+ return text[:cut]
105
+
106
+
107
def generate_text(req: GenerateRequest) -> tuple[str, int, int]:
    """Run one generation for ``req``.

    Args:
        req: Parsed request body.

    Returns:
        ``(text, prompt_tokens, completion_tokens)`` where ``text`` is the
        decoded continuation after stop-string trimming, and
        ``completion_tokens`` counts every generated token (before trimming).
    """
    prompt = make_prompt(req)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_len = int(inputs.input_ids.shape[-1])
    # Move the inputs to the device of the model's first parameter.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # A temperature of zero (or less) selects non-sampling decoding.
    do_sample = req.temperature is not None and req.temperature > 0
    # Fall back to EOS only when the tokenizer defines no pad token; a plain
    # `or` would also discard a legitimate pad id of 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    kwargs = {
        "max_new_tokens": max(1, int(req.max_tokens)),
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        kwargs["temperature"] = float(req.temperature)
        kwargs["top_p"] = float(req.top_p)
    # One generation at a time, without autograd bookkeeping.
    with lock, torch.inference_mode():
        output = model.generate(**inputs, **kwargs)
    # Drop the echoed prompt tokens from the single returned sequence.
    new_tokens = output[0, input_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    text = strip_stop(text, req.stop)
    return text, input_len, int(new_tokens.numel())
129
+
130
+
131
@app.get("/healthz")
def healthz() -> dict[str, Any]:
    """Liveness probe; does not require authentication."""
    status: dict[str, Any] = {"ok": True}
    status["model"] = args.served_model_name
    return status
134
+
135
+
136
@app.get("/v1/models")
def models(authorization: str | None = Header(default=None)) -> dict[str, Any]:
    """Return the model list, which contains only the served model."""
    require_auth(authorization)
    served_entry = {
        "id": args.served_model_name,
        "object": "model",
        "created": 0,
        "owned_by": "local",
    }
    return {"object": "list", "data": [served_entry]}
150
+
151
+
152
@app.post("/v1/chat/completions")
def chat(req: GenerateRequest, authorization: str | None = Header(default=None)) -> dict[str, Any]:
    """Handle a non-streaming chat completion request.

    Raises:
        HTTPException: 401 on failed authentication, 400 when ``stream`` is
            requested.
    """
    require_auth(authorization)
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=true is not implemented")
    text, prompt_tokens, completion_tokens = generate_text(req)
    # Exactly one choice is produced; finish_reason is always reported as
    # "stop", whatever actually ended the generation.
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": text},
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or args.served_model_name,
        "choices": [choice],
        "usage": usage,
    }
176
+
177
+
178
@app.post("/v1/completions")
def completions(req: GenerateRequest, authorization: str | None = Header(default=None)) -> dict[str, Any]:
    """Handle a non-streaming text completion request.

    Raises:
        HTTPException: 401 on failed authentication, 400 when ``stream`` is
            requested.
    """
    require_auth(authorization)
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=true is not implemented")
    text, prompt_tokens, completion_tokens = generate_text(req)
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    # finish_reason is always reported as "stop".
    only_choice = {"text": text, "index": 0, "finish_reason": "stop"}
    return {
        "id": f"cmpl-{uuid.uuid4().hex}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": req.model or args.served_model_name,
        "choices": [only_choice],
        "usage": usage,
    }
196
+
197
+
198
if __name__ == "__main__":
    # Serve the app on the address selected by --host / --port.
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
LTA_openwebtext_dualt/scripts/run_lta_owt_bert_absrope_time4_dirichlet_len1024_Cv_to_2v_8gpu_1m_mask1_sameT_save10k.sh ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Experiment wrapper: fixes the configuration for one OWT/BERT run through
# environment variables; the shared launcher invoked at the end of this file
# reads them.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

# OWT raw text + BERT tokenizer, FLM wrapped stream:
#   [CLS] + 1022 payload tokens + [SEP]
#
# Backbone:
#   ddit_elf = RMSNorm/SwiGLU/QK-norm + RoPE + 4 prefix time tokens.
#   We also add learned absolute position embeddings before RoPE.
#
# Bridge:
#   Dirichlet C=V->2V, mask_ratio=1.0, model t and corruption t are shared.

# --- Data and tokenizer (overridable, except TOKENIZED_HF) ---
: "${DATA_PATH:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
: "${TEXT_COLUMN:=text}"
: "${OPENWEBTEXT_SPLIT:=train_minus_100k}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
export DATA_PATH TEXT_COLUMN OPENWEBTEXT_SPLIT TOKENIZER_PATH
export TOKENIZED_HF=0
: "${WRAP_MODE:=stream}"
export WRAP_MODE

# --- Vocabulary size and Dirichlet concentration range (V and 2V) ---
: "${VOCAB_SIZE:=30522}"
: "${CMIN:=30522}"
: "${CMAX:=61044}"
export VOCAB_SIZE CMIN CMAX

# --- Backbone and corruption settings; the plain assignments are pinned ---
export MODEL_TYPE=ddit_elf
: "${ELF_NUM_TIME_TOKENS:=4}"
: "${ELF_NUM_MODEL_MODE_TOKENS:=0}"
: "${QK_NORM:=1}"
export ELF_NUM_TIME_TOKENS ELF_NUM_MODEL_MODE_TOKENS QK_NORM
export ABS_POS_EMBED=1
export CORRUPT_T_MODE=same
export MIN_MASK_RATIO=1.0
export MAX_MASK_RATIO=1.0
: "${CATEGORICAL_WRONG_PROB_FLOOR:=0.0}"
export CATEGORICAL_WRONG_PROB_FLOOR
36
+
37
# Work out how many GPUs this node has, then default CUDA_VISIBLE_DEVICES to
# all of them and NPROC_PER_NODE to the number of visible devices.
#
# Only lines that are a bare GPU index are counted, so a missing or failing
# nvidia-smi (which may print a diagnostic instead) yields 0 and falls back
# to a single device. `|| true` absorbs grep's non-zero status on zero
# matches without adding anything to the captured output.
_ngpus_avail=0
if command -v nvidia-smi >/dev/null 2>&1; then
  _ngpus_avail="$(nvidia-smi --query-gpu=index --format=csv,noheader 2>/dev/null \
    | grep -c -E '^[[:space:]]*[0-9]+[[:space:]]*$' || true)"
fi
if ! [[ "${_ngpus_avail}" =~ ^[0-9]+$ ]] || [[ "${_ngpus_avail}" -le 0 ]]; then
  _ngpus_avail=1
fi
_default_cvd="$(seq -s, 0 $((_ngpus_avail - 1)))"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${_default_cvd}}"
# One process per visible device unless NPROC_PER_NODE is given explicitly.
IFS=',' read -ra _cvd_arr <<< "${CUDA_VISIBLE_DEVICES}"
export NPROC_PER_NODE="${NPROC_PER_NODE:-${#_cvd_arr[@]}}"
unset _ngpus_avail _default_cvd _cvd_arr
44
# --- Distributed topology; MLP_* values are used when present ---
: "${NNODES:=${MLP_WORKER_NUM:-1}}"
: "${NODE_RANK:=${MLP_ROLE_INDEX:-0}}"
: "${MASTER_ADDR:=${MLP_WORKER_0_HOST:-127.0.0.1}}"
: "${MASTER_PORT:=${MLP_WORKER_0_PORT:-29500}}"
export NNODES NODE_RANK MASTER_ADDR MASTER_PORT

# --- Batch sizes, schedule length and checkpoint/log cadence ---
: "${GLOBAL_BATCH_SIZE:=512}"
: "${PER_GPU_BATCH_SIZE:=32}"
: "${TOTAL_STEPS:=1000000}"
: "${WARMUP_STEPS:=2500}"
: "${SAVE_EVERY:=10000}"
: "${LATEST_EVERY:=1000}"
: "${LOG_EVERY:=100}"
export GLOBAL_BATCH_SIZE PER_GPU_BATCH_SIZE TOTAL_STEPS WARMUP_STEPS
export SAVE_EVERY LATEST_EVERY LOG_EVERY

# --- Run name, suffixed with today's date unless DATE_TAG is given ---
: "${DATE_TAG:=$(date +%Y%m%d)}"
export DATE_TAG
: "${RUN_NAME:=lta_owt_bert_absrope_time4_dirichlet_len1024_Cv_to_2v_mask1_sameT_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
export RUN_NAME

# --- Sample-watcher settings ---
: "${WATCH_ENABLED:=1}"
: "${WATCH_STEP_INTERVAL:=10000}"
: "${WATCH_N_SAMPLES:=128}"
: "${WATCH_CUDA_VISIBLE_DEVICES:=7}"
: "${WATCH_ENDPOINT_TEMP:=1.45}"
: "${WATCH_ENDPOINT_TOP_P:=0.95}"
: "${WATCH_GUMBEL_TAU_START:=1.0}"
: "${WATCH_GUMBEL_TAU_END:=0.2}"
export WATCH_ENABLED WATCH_STEP_INTERVAL WATCH_N_SAMPLES
export WATCH_CUDA_VISIBLE_DEVICES WATCH_ENDPOINT_TEMP WATCH_ENDPOINT_TOP_P
export WATCH_GUMBEL_TAU_START WATCH_GUMBEL_TAU_END
# The default output path embeds the settings above; dots in the numeric
# values are written as "p" so they are safe inside directory names.
: "${WATCH_OUT_BASE:=docs/lta_samples/metrics_${DATE_TAG}/owt_bert_absrope_time4_Cv_to_2v_mask1_sameT_sde_gumbel_topp${WATCH_ENDPOINT_TOP_P//./p}_tau${WATCH_GUMBEL_TAU_START//./p}_to_${WATCH_GUMBEL_TAU_END//./p}_blend_c${CMIN}_${CMAX}_n${WATCH_N_SAMPLES}/${RUN_NAME}}"
: "${WATCH_LOG_DIR:=logs/owt_bert_absrope_time4_Cv_to_2v_mask1_sameT_gumbel_sde_watch}"
export WATCH_OUT_BASE WATCH_LOG_DIR

# Hand over to the shared launcher, which reads the exported variables.
bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
LTA_openwebtext_dualt/scripts/run_train8_rollin_focused_pilots_4gpu.sh ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Driver for the "rollin focused" pilot sweep: trains several rollout
# configurations in fixed-size step chunks and decode-evaluates after each
# chunk. This section sets up paths, sizes and the result CSV.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1

# Tunables; every one of them can be overridden from the environment.
: "${BASE_CACHE:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
: "${MAX_LEN:=256}"
: "${N_SAMPLES:=64}"
: "${INFER_STEPS:=128}"
: "${STEP_CHUNK:=500}"              # training steps per round and config
: "${MAX_TOTAL_STEPS:=12000}"       # per-config training cap
: "${STOP_EXACT_COUNT:=64}"         # stop threshold on exact_count
: "${STOP_EXACT_REF_COUNT:=8}"      # stop threshold on exact_ref_count
: "${PER_GPU_BATCH_SIZE:=128}"
: "${GLOBAL_BATCH_SIZE:=512}"
: "${GROUP_STAMP:=$(date +%Y%m%d_%H%M%S)}"
: "${OUT_ROOT:=docs/lta_samples/metrics_20260517/rollin_focused_len${MAX_LEN}_bs512_ode128_${GROUP_STAMP}}"
: "${DRIVER_LOG:=logs/rollin_focused_4gpu/${GROUP_STAMP}.log}"
: "${CURVE_CSV:=${OUT_ROOT}/hit_ratio_curve.csv}"
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"

# Cached chunk directory for this sequence length; its meta.json supplies the
# vocabulary size (the compact size when present, else the plain one).
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(
  python - "$cache" <<'PY'
import json
import sys
from pathlib import Path
meta_path = Path(sys.argv[1]) / "meta.json"
meta = json.loads(meta_path.read_text())
size = meta["compact_vocab_size"] if "compact_vocab_size" in meta else meta.get("vocab_size")
print(int(size))
PY
)"

# Write the CSV header once; later runs with the same CURVE_CSV append rows.
[[ -f "${CURVE_CSV}" ]] || echo "config,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
41
+
42
#######################################
# Print the highest checkpoint step saved for a run.
# Arguments: $1 - run name (directory under runs/)
# Outputs:   the largest N among runs/<run>/step_N.pt on stdout, or 0 when
#            there is no such file
#######################################
latest_step() {
  local run="$1"
  python - "$run" <<'PY'
import re
import sys
from pathlib import Path
step_file = re.compile(r"step_(\d+)\.pt$")
run_dir = Path("runs") / sys.argv[1]
matches = (step_file.search(p.name) for p in run_dir.glob("step_*.pt"))
numbers = [int(m.group(1)) for m in matches if m]
print(max(numbers, default=0))
PY
}
57
+
58
#######################################
# Print a TCP port that was free on 127.0.0.1 at the time of the call.
# The probing socket is closed before the port is printed to the caller, so
# another process may take the port before it is used.
# Outputs: the port number on stdout
#######################################
free_port() {
  python - <<'PY'
import socket
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    # Port 0 asks the OS to choose any unused port.
    probe.bind(("127.0.0.1", 0))
    print(probe.getsockname()[1])
finally:
    probe.close()
PY
}
66
+
67
#######################################
# Decode-evaluate a run, append the result to the curve CSV, and report
# whether the stop criteria are met.
# Globals:   OUT_ROOT, cache, TOKENIZER_PATH, MAX_LEN, N_SAMPLES, INFER_STEPS,
#            GLOBAL_BATCH_SIZE, CURVE_CSV, STOP_EXACT_COUNT,
#            STOP_EXACT_REF_COUNT, EVAL_CUDA_VISIBLE_DEVICES (all read)
# Arguments: $1 - config name, $2 - run name, $3 - step just trained to
# Outputs:   a "RESULT ..." line on stdout; one row appended to CURVE_CSV
# Returns:   0 when exact_count and exact_ref_count both reach their stop
#            thresholds; non-zero otherwise, including when the evaluation
#            itself fails
#######################################
eval_latest() {
  local config="$1"
  local run_name="$2"
  local target_step="$3"
  local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
  mkdir -p "${out_dir}"
  local -a eval_args=(
    --runs_glob "runs/${run_name}"
    --data_dir "${cache}"
    --tokenizer_path "${TOKENIZER_PATH}"
    --out_dir "${out_dir}"
    --max_len "${MAX_LEN}"
    --n_samples "${N_SAMPLES}"
    --batch_size "${N_SAMPLES}"
    --latest_only
    --endpoint_softenings none
    --steps "${INFER_STEPS}"
    --decode_rule flowmap
    --time_schedule logit_normal
    --time_logit_mean -1.5
    --time_logit_std 0.8
    --model_t_mode post
    --c_min 1
    --c_max 512
    --late_temp 1.0
    --final_from state
    --final_decode argmax
  )
  # This function is called from an `if` condition, where errexit is
  # suppressed. Check the evaluation explicitly so a crash cannot fall through
  # to the summary step and record a stale decode_token_acc.jsonl left in
  # out_dir by an earlier attempt.
  if ! CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" \
    python scripts/eval_train8_decode_acc.py "${eval_args[@]}"; then
    echo "[eval_latest] evaluation failed config=${config} run=${run_name} step=${target_step}" >&2
    return 1
  fi
  # Summarise the last JSONL row, append it to the CSV, and turn the stop
  # criteria into the exit status.
  python - "$config" "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$CURVE_CSV" "$STOP_EXACT_COUNT" "$STOP_EXACT_REF_COUNT" <<'PY'
import json
import sys
from pathlib import Path
config = sys.argv[1]
out = Path(sys.argv[2])
n = int(sys.argv[3])
global_batch = int(sys.argv[4])
max_len = int(sys.argv[5])
curve = Path(sys.argv[6])
stop_exact_count = int(sys.argv[7])
stop_exact_ref_count = int(sys.argv[8])
row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
views = int(row["ckpt_step"]) * global_batch
tokens = views * max_len
print(
    "RESULT "
    f"config={config} ckpt_step={row['ckpt_step']} views={views} "
    f"token_acc={row['token_acc_mean']:.4f} exact={row['exact_count']}/{n} "
    f"exact_refs={row['exact_ref_count']} hits={row['exact_ref_hits']}",
    flush=True,
)
with curve.open("a", encoding="utf-8") as f:
    f.write(
        f"{config},{row['ckpt_step']},{views},{tokens},{row['token_acc_mean']},"
        f"{row['exact_count']},{row['exact_ref_count']},\"{row['exact_ref_hits']}\"\n"
    )
raise SystemExit(
    0
    if int(row["exact_count"]) >= stop_exact_count
    and int(row["exact_ref_count"]) >= stop_exact_ref_count
    else 1
)
PY
}
129
+
130
#######################################
# Export the baseline training settings shared by every config, overwriting
# whatever a previously configured config left in the environment.
# Globals: the 21 variables listed in `baseline` (written and exported)
#######################################
reset_defaults() {
  local assignment
  local -a baseline=(
    MIN_MASK_RATIO=1.0
    MAX_MASK_RATIO=1.0
    MASK_MIXTURE_LOWK_PROB=0.0
    MASK_MIXTURE_ALL_PROB=1.0
    LOWK_CLEAN_TOKENS=0
    CLEAN_STATE_MODE=onehot
    TARGET_LOSS=hard_ce
    DIRICHLET_CONCENTRATION_MIN=1.0
    DIRICHLET_CONCENTRATION_MAX=1024
    SIMPLEX_BRIDGE_SAMPLER=dirichlet
    CATEGORICAL_WRONG_UNIGRAM_PROB=0.0
    CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB=0.0
    CATEGORICAL_WRONG_PROB_FLOOR=0.0
    ROLLOUT_TRAIN_PROB=0.0
    ROLLOUT_TRAIN_STEPS=1
    ROLLOUT_TRAIN_INFER_STEPS=64
    ROLLOUT_TRAIN_TEMP=1.45
    ROLLOUT_TRAIN_MAX_GAMMA=1.0
    ROLLOUT_TRAIN_CORRUPT_ONLY=1
    ROLLOUT_TRAIN_SAMPLEWISE=1
    ROLLOUT_TRAIN_COMPUTE_ALWAYS=0
  )
  for assignment in "${baseline[@]}"; do
    export "${assignment}"
  done
}
153
+
154
#######################################
# Reset to the baseline and apply the rollout settings of one named config.
# Globals:   ROLLOUT_TRAIN_PROB, ROLLOUT_TRAIN_STEPS,
#            ROLLOUT_TRAIN_INFER_STEPS, ROLLOUT_TRAIN_TEMP (exported)
# Arguments: $1 - config name
# Returns:   0 on success, 2 for an unknown config name (the baseline has
#            already been re-applied at that point)
#######################################
configure() {
  local config="$1"
  local prob steps infer_steps temp=""
  reset_defaults
  # Each name encodes probability, rollout steps and inference steps; the
  # optional temp suffix additionally overrides the baseline temperature.
  case "${config}" in
    rollin_p50_s4_i32)         prob=0.50; steps=4; infer_steps=32 ;;
    rollin_p75_s4_i32)         prob=0.75; steps=4; infer_steps=32 ;;
    rollin_p100_s4_i32)        prob=1.00; steps=4; infer_steps=32 ;;
    rollin_p50_s8_i64)         prob=0.50; steps=8; infer_steps=64 ;;
    rollin_p75_s8_i64)         prob=0.75; steps=8; infer_steps=64 ;;
    rollin_p50_s4_i32_temp1p0) prob=0.50; steps=4; infer_steps=32; temp=1.0 ;;
    *)
      echo "unknown config: ${config}" >&2
      return 2
      ;;
  esac
  export ROLLOUT_TRAIN_PROB="${prob}"
  export ROLLOUT_TRAIN_STEPS="${steps}"
  export ROLLOUT_TRAIN_INFER_STEPS="${infer_steps}"
  if [[ -n "${temp}" ]]; then
    export ROLLOUT_TRAIN_TEMP="${temp}"
  fi
}
195
+
196
# Configurations visited in every round, in this order.
configs=(
  rollin_p50_s4_i32
  rollin_p75_s4_i32
  rollin_p100_s4_i32
  rollin_p50_s8_i64
  rollin_p75_s8_i64
  rollin_p50_s4_i32_temp1p0
)

# Space-delimited names of configs whose last training chunk left no newer
# step_*.pt checkpoint. They are not rescheduled: latest_step would report
# the same step every round and the loop below would never terminate.
stalled_configs=" "

echo "[rollin-focused] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"

# Round-robin driver: each round trains every unfinished config for one more
# chunk of STEP_CHUNK steps (resuming from latest.pt), evaluates it, and marks
# it DONE once eval_latest reports that the stop criteria are met. The loop
# ends when no config was trained in a round.
round_idx=0
while :; do
  round_idx=$((round_idx + 1))
  active=0
  echo "[rollin-focused] round=${round_idx} $(date)" | tee -a "${DRIVER_LOG}"
  for config in "${configs[@]}"; do
    run_name="train8_rollin_focused_len${MAX_LEN}_${config}_${GROUP_STAMP}"
    done_flag="${OUT_ROOT}/${config}/DONE"
    if [[ -f "${done_flag}" ]]; then
      echo "[rollin-focused] skip done config=${config}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    if [[ "${stalled_configs}" == *" ${config} "* ]]; then
      echo "[rollin-focused] skip stalled config=${config}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    step_now="$(latest_step "${run_name}")"
    if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
      echo "[rollin-focused] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    active=1
    # Train one more chunk, but never past the overall cap.
    target_step=$((step_now + STEP_CHUNK))
    if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
      target_step="${MAX_TOTAL_STEPS}"
    fi
    resume_path=""
    if [[ -f "runs/${run_name}/latest.pt" ]]; then
      resume_path="runs/${run_name}/latest.pt"
    fi
    configure "${config}"
    echo "[rollin-focused] train config=${config} from=${step_now} to=${target_step} rollout=${ROLLOUT_TRAIN_PROB}/s${ROLLOUT_TRAIN_STEPS}/i${ROLLOUT_TRAIN_INFER_STEPS}/temp${ROLLOUT_TRAIN_TEMP}" | tee -a "${DRIVER_LOG}"
    # All settings reach the launcher through its environment; each chunk
    # gets a freshly probed rendezvous port.
    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
    NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
    MASTER_PORT="$(free_port)" \
    OWT_CHUNK_CACHE_DIR="${cache}" \
    OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
    MAX_LEN="${MAX_LEN}" \
    VOCAB_SIZE_OVERRIDE="${vocab_size}" \
    D_MODEL="${D_MODEL:-192}" \
    COND_DIM="${COND_DIM:-64}" \
    N_LAYERS="${N_LAYERS:-3}" \
    N_HEADS="${N_HEADS:-3}" \
    DIM_FF="${DIM_FF:-768}" \
    TOTAL_STEPS="${target_step}" \
    PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
    GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
    NUM_WORKERS="${NUM_WORKERS:-0}" \
    LOG_EVERY="${LOG_EVERY:-100}" \
    SAVE_EVERY="${STEP_CHUNK}" \
    LATEST_EVERY="${STEP_CHUNK}" \
    WARMUP_STEPS="${WARMUP_STEPS:-10}" \
    LEARNING_RATE="${LEARNING_RATE:-0.002}" \
    WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" \
    MUON_IMPL="${MUON_IMPL:-legacy}" \
    OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}" \
    RUN_NAME="${run_name}" \
    RESUME_PATH="${resume_path}" \
    bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
    echo "[rollin-focused] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    # With pipefail the pipeline status is eval_latest's status (tee succeeds).
    if eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"; then
      touch "${done_flag}"
      echo "[rollin-focused] done config=${config}" | tee -a "${DRIVER_LOG}"
    fi
    # Progress check: the next round derives its starting point from the
    # step_*.pt files, so a chunk that added none must not be retried.
    step_after="$(latest_step "${run_name}")"
    if [[ "${step_after}" -le "${step_now}" ]]; then
      stalled_configs+="${config} "
      echo "[rollin-focused] stalled config=${config}: no step_*.pt newer than ${step_now} after training to ${target_step}; not rescheduling" | tee -a "${DRIVER_LOG}"
    fi
  done
  if [[ "${active}" -eq 0 ]]; then
    echo "[rollin-focused] all capped/done $(date)" | tee -a "${DRIVER_LOG}"
    break
  fi
done
LTA_openwebtext_dualt/scripts/run_train8_selected_long20k_4gpu.sh ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Driver that continues two selected runs in fixed-size step chunks up to
# MAX_TOTAL_STEPS, decode-evaluating after each chunk. This section sets up
# paths, sizes and the result CSV.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1

# Tunables; every one of them can be overridden from the environment.
: "${BASE_CACHE:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
: "${MAX_LEN:=256}"
: "${N_SAMPLES:=64}"
: "${INFER_STEPS:=128}"
: "${STEP_CHUNK:=1000}"             # training steps per round and config
: "${MAX_TOTAL_STEPS:=20000}"       # per-run training cap
: "${PER_GPU_BATCH_SIZE:=128}"
: "${GLOBAL_BATCH_SIZE:=512}"
: "${GROUP_STAMP:=$(date +%Y%m%d_%H%M%S)}"
: "${OUT_ROOT:=docs/lta_samples/metrics_20260517/selected_long20k_len${MAX_LEN}_bs512_ode128_${GROUP_STAMP}}"
: "${DRIVER_LOG:=logs/selected_long20k_4gpu/${GROUP_STAMP}.log}"
: "${CURVE_CSV:=${OUT_ROOT}/hit_ratio_curve.csv}"
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"

# Cached chunk directory for this sequence length; its meta.json supplies the
# vocabulary size (the compact size when present, else the plain one).
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(
  python - "$cache" <<'PY'
import json
import sys
from pathlib import Path
meta_path = Path(sys.argv[1]) / "meta.json"
meta = json.loads(meta_path.read_text())
size = meta["compact_vocab_size"] if "compact_vocab_size" in meta else meta.get("vocab_size")
print(int(size))
PY
)"

# Write the CSV header once; later runs with the same CURVE_CSV append rows.
[[ -f "${CURVE_CSV}" ]] || echo "config,run_name,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
39
+
40
#######################################
# Print the highest checkpoint step saved for a run.
# Arguments: $1 - run name (directory under runs/)
# Outputs:   the largest N among runs/<run>/step_N.pt on stdout, or 0 when
#            there is no such file
#######################################
latest_step() {
  local run="$1"
  python - "$run" <<'PY'
import re
import sys
from pathlib import Path
step_file = re.compile(r"step_(\d+)\.pt$")
run_dir = Path("runs") / sys.argv[1]
matches = (step_file.search(p.name) for p in run_dir.glob("step_*.pt"))
numbers = [int(m.group(1)) for m in matches if m]
print(max(numbers, default=0))
PY
}
55
+
56
#######################################
# Print a TCP port that was free on 127.0.0.1 at the time of the call.
# The probing socket is closed before the port is printed to the caller, so
# another process may take the port before it is used.
# Outputs: the port number on stdout
#######################################
free_port() {
  python - <<'PY'
import socket
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    # Port 0 asks the OS to choose any unused port.
    probe.bind(("127.0.0.1", 0))
    print(probe.getsockname()[1])
finally:
    probe.close()
PY
}
64
+
65
#######################################
# Decode-evaluate a run and append the result to the curve CSV.
# Globals:   OUT_ROOT, cache, TOKENIZER_PATH, MAX_LEN, N_SAMPLES, INFER_STEPS,
#            GLOBAL_BATCH_SIZE, CURVE_CSV, EVAL_CUDA_VISIBLE_DEVICES (read)
# Arguments: $1 - config name, $2 - run name, $3 - step just trained to
# Outputs:   a "RESULT ..." line on stdout; one row appended to CURVE_CSV
#######################################
eval_latest() {
  local config="$1"
  local run_name="$2"
  local target_step="$3"
  local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
  mkdir -p "${out_dir}"
  local -a eval_args=(
    --runs_glob "runs/${run_name}"
    --data_dir "${cache}"
    --tokenizer_path "${TOKENIZER_PATH}"
    --out_dir "${out_dir}"
    --max_len "${MAX_LEN}"
    --n_samples "${N_SAMPLES}"
    --batch_size "${N_SAMPLES}"
    --latest_only
    --endpoint_softenings none
    --steps "${INFER_STEPS}"
    --decode_rule flowmap
    --time_schedule logit_normal
    --time_logit_mean -1.5
    --time_logit_std 0.8
    --model_t_mode post
    --c_min 1
    --c_max 512
    --late_temp 1.0
    --final_from state
    --final_decode argmax
  )
  CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" \
    python scripts/eval_train8_decode_acc.py "${eval_args[@]}"
  # Summarise the last JSONL row written by the evaluation and append it to
  # the CSV.
  python - "$config" "$run_name" "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$CURVE_CSV" <<'PY'
import json
import sys
from pathlib import Path
config, run_name = sys.argv[1], sys.argv[2]
out = Path(sys.argv[3])
n, global_batch, max_len = (int(v) for v in sys.argv[4:7])
curve = Path(sys.argv[7])
last_line = (out / "decode_token_acc.jsonl").read_text().splitlines()[-1]
row = json.loads(last_line)
views = int(row["ckpt_step"]) * global_batch
tokens = views * max_len
print(
    "RESULT "
    f"config={config} run={run_name} ckpt_step={row['ckpt_step']} views={views} "
    f"token_acc={row['token_acc_mean']:.4f} exact={row['exact_count']}/{n} "
    f"exact_refs={row['exact_ref_count']} hits={row['exact_ref_hits']}",
    flush=True,
)
csv_row = (
    f"{config},{run_name},{row['ckpt_step']},{views},{tokens},{row['token_acc_mean']},"
    f"{row['exact_count']},{row['exact_ref_count']},\"{row['exact_ref_hits']}\"\n"
)
with curve.open("a", encoding="utf-8") as f:
    f.write(csv_row)
PY
}
120
+
121
#######################################
# Export the baseline training settings shared by every config, overwriting
# whatever a previously configured config left in the environment.
# Globals: the 22 variables listed in `baseline` (written and exported)
#######################################
reset_common() {
  local assignment
  local -a baseline=(
    MIN_MASK_RATIO=1.0
    MAX_MASK_RATIO=1.0
    MASK_MIXTURE_LOWK_PROB=0.0
    MASK_MIXTURE_ALL_PROB=1.0
    LOWK_CLEAN_TOKENS=0
    CLEAN_STATE_MODE=onehot
    TARGET_LOSS=hard_ce
    DIRICHLET_CONCENTRATION_MIN=1.0
    DIRICHLET_CONCENTRATION_MAX=1024
    SIMPLEX_BRIDGE_SAMPLER=dirichlet
    CATEGORICAL_WRONG_PROB_FLOOR=0.0
    CATEGORICAL_WRONG_UNIGRAM_PROB=0.0
    CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB=0.0
    ROLLOUT_TRAIN_PROB=0.0
    ROLLOUT_TRAIN_STEPS=1
    ROLLOUT_TRAIN_INFER_STEPS=64
    ROLLOUT_TRAIN_TEMP=1.45
    ROLLOUT_TRAIN_MAX_GAMMA=1.0
    ROLLOUT_TRAIN_CORRUPT_ONLY=1
    ROLLOUT_TRAIN_SAMPLEWISE=1
    ROLLOUT_TRAIN_COMPUTE_ALWAYS=0
    ROLLOUT_TRAIN_SYNC_T=0
  )
  for assignment in "${baseline[@]}"; do
    export "${assignment}"
  done
}
145
+
146
#######################################
# Reset to the baseline, then select the existing run to continue and apply
# that run's rollout settings.
# Globals:   RUN_NAME_SELECTED (written); ROLLOUT_TRAIN_* (exported)
# Arguments: $1 - config name
# Returns:   0 on success, 2 for an unknown config name (the baseline has
#            already been re-applied at that point)
#######################################
configure() {
  local config="$1"
  reset_common
  if [[ "${config}" == "baseline_allcorrupt" ]]; then
    # Baseline run: the common defaults are used unchanged.
    RUN_NAME_SELECTED="train8_noisegeo_len256_allcorrupt_fullvocab_dirC1_1024_20260517_163805"
  elif [[ "${config}" == "rollin_p50_s4_old" ]]; then
    RUN_NAME_SELECTED="train8_rollin_len256_rollin_p50_s4_i32_20260517_171654"
    export ROLLOUT_TRAIN_PROB=0.50
    export ROLLOUT_TRAIN_STEPS=4
    export ROLLOUT_TRAIN_INFER_STEPS=32
    # Same values as the common defaults, restated for this run.
    export ROLLOUT_TRAIN_TEMP=1.45
    export ROLLOUT_TRAIN_SYNC_T=0
  else
    echo "unknown config: ${config}" >&2
    return 2
  fi
}
167
+
168
# Configurations visited in every round, in this order.
configs=(
  baseline_allcorrupt
  rollin_p50_s4_old
)

# Space-delimited names of configs whose last training chunk left no newer
# step_*.pt checkpoint. They are not rescheduled: latest_step would report
# the same step every round and the loop below would never terminate.
stalled_configs=" "

echo "[selected-long20k] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"
# Round-robin driver: each round continues every selected run by one chunk of
# STEP_CHUNK steps (resuming from latest.pt) and evaluates it. The loop ends
# when no run was trained in a round.
while :; do
  active=0
  for config in "${configs[@]}"; do
    if [[ "${stalled_configs}" == *" ${config} "* ]]; then
      echo "[selected-long20k] skip stalled config=${config}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    # configure also selects which existing run this config continues.
    configure "${config}"
    run_name="${RUN_NAME_SELECTED}"
    step_now="$(latest_step "${run_name}")"
    if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
      echo "[selected-long20k] capped config=${config} run=${run_name} step=${step_now}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    active=1
    # Train one more chunk, but never past the overall cap.
    target_step=$((step_now + STEP_CHUNK))
    if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
      target_step="${MAX_TOTAL_STEPS}"
    fi
    resume_path=""
    if [[ -f "runs/${run_name}/latest.pt" ]]; then
      resume_path="runs/${run_name}/latest.pt"
    fi
    echo "[selected-long20k] train config=${config} run=${run_name} from=${step_now} to=${target_step} rollout=${ROLLOUT_TRAIN_PROB}/s${ROLLOUT_TRAIN_STEPS}/i${ROLLOUT_TRAIN_INFER_STEPS}" | tee -a "${DRIVER_LOG}"
    # All settings reach the launcher through its environment; each chunk
    # gets a freshly probed rendezvous port.
    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
    NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
    MASTER_PORT="$(free_port)" \
    OWT_CHUNK_CACHE_DIR="${cache}" \
    OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
    MAX_LEN="${MAX_LEN}" \
    VOCAB_SIZE_OVERRIDE="${vocab_size}" \
    D_MODEL="${D_MODEL:-192}" \
    COND_DIM="${COND_DIM:-64}" \
    N_LAYERS="${N_LAYERS:-3}" \
    N_HEADS="${N_HEADS:-3}" \
    DIM_FF="${DIM_FF:-768}" \
    TOTAL_STEPS="${target_step}" \
    PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
    GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
    NUM_WORKERS="${NUM_WORKERS:-0}" \
    LOG_EVERY="${LOG_EVERY:-100}" \
    SAVE_EVERY="${STEP_CHUNK}" \
    LATEST_EVERY="${STEP_CHUNK}" \
    WARMUP_STEPS="${WARMUP_STEPS:-10}" \
    LEARNING_RATE="${LEARNING_RATE:-0.002}" \
    WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" \
    MUON_IMPL="${MUON_IMPL:-legacy}" \
    OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}" \
    RUN_NAME="${run_name}" \
    RESUME_PATH="${resume_path}" \
    bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
    echo "[selected-long20k] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"
    # Progress check: the next round derives its starting point from the
    # step_*.pt files, so a chunk that added none must not be retried.
    step_after="$(latest_step "${run_name}")"
    if [[ "${step_after}" -le "${step_now}" ]]; then
      stalled_configs+="${config} "
      echo "[selected-long20k] stalled config=${config} run=${run_name}: no step_*.pt newer than ${step_now} after training to ${target_step}; not rescheduling" | tee -a "${DRIVER_LOG}"
    fi
  done
  if [[ "${active}" -eq 0 ]]; then
    echo "[selected-long20k] all capped $(date)" | tee -a "${DRIVER_LOG}"
    break
  fi
done
LTA_openwebtext_dualt/scripts/score_lta_decode_strategy.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ REPO_ROOT = Path(__file__).resolve().parents[1]
12
+ if str(REPO_ROOT) not in sys.path:
13
+ sys.path.insert(0, str(REPO_ROOT))
14
+
15
+ from eval import build_model_from_ckpt
16
+ from flowtext_lab.decode import sample_noise_simplex
17
+ from flowtext_lab.genppl import filter_generated_texts, score_gen_ppl, summarize_token_diversity
18
+ from flowtext_lab.tokenization import BpeTextTokenizer
19
+ from scripts.flowtext_decode_lab import DecodeConfig, decode_batch
20
+
21
+
22
def parse_args() -> argparse.Namespace:
    """Parse the command line for the decode-and-score script.

    ``--checkpoint``, ``--tokenizer_path``, ``--output`` and ``--gen_ppl_model``
    are required; every other flag has a default.

    Count-like flags (``--samples``, ``--batch_size``, ``--max_len`` and
    ``--gen_ppl_batch_size``) must be strictly positive. A batch size of zero
    would otherwise make a "decode until N samples are produced" loop spin
    without ever making progress, and negative sizes only fail later inside
    tensor construction with a much less helpful message.

    Returns:
        argparse.Namespace: one attribute per flag.

    Raises:
        SystemExit: raised by argparse (exit code 2) for a missing required
            flag, an invalid choice, or a non-positive count.
    """

    def positive_int(raw: str) -> int:
        # A ValueError from int() is reported by argparse as a normal usage error.
        value = int(raw)
        if value <= 0:
            raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
        return value

    p = argparse.ArgumentParser()

    # Model / tokenizer inputs and sampling budget.
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--tokenizer_path", required=True)
    p.add_argument("--max_len", type=positive_int, default=128)
    p.add_argument("--samples", type=positive_int, default=128)
    p.add_argument("--batch_size", type=positive_int, default=4)
    p.add_argument("--steps", type=int, default=128)
    p.add_argument("--seed", type=int, default=20260502)

    # Initial noise on the simplex.
    p.add_argument("--noise_init", choices=["uniform", "logistic_normal", "dirichlet"], default="dirichlet")
    p.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    p.add_argument("--target_prob", type=float, default=1.0)
    p.add_argument("--noise_sigma", type=float, default=-1.0)
    p.add_argument("--eps", type=float, default=1e-8)

    # Decode rule and its knobs (forwarded verbatim to DecodeConfig by the caller).
    p.add_argument("--rule", choices=["flowmap", "replace", "geometric", "centered_residual"], default="flowmap")
    p.add_argument("--model_t_mode", choices=["linear", "flow", "const0", "const05", "const1", "random"], default="flow")
    p.add_argument("--damping", type=float, default=1.0)
    p.add_argument("--max_gamma", type=float, default=1.0)
    p.add_argument("--eta", type=float, default=0.5)
    p.add_argument("--endpoint_temp", type=float, default=1.0)
    p.add_argument("--final_from", choices=["state", "endpoint", "blend"], default="state")
    p.add_argument("--state_floor", type=float, default=1e-8)
    p.add_argument("--eos_logit_bias", type=float, default=0.0)

    # Output file and generative-perplexity scorer.
    p.add_argument("--output", required=True)
    p.add_argument("--gen_ppl_model", required=True)
    p.add_argument("--gen_ppl_batch_size", type=positive_int, default=4)
    p.add_argument("--gen_ppl_max_length", type=int, default=1024)
    return p.parse_args()
50
+
51
+
52
@torch.no_grad()
def main() -> None:
    """Decode samples from a checkpoint, dump them as JSONL, and print metrics.

    Steps, in order:
      1. Parse CLI flags, seed torch, and pick CUDA when available (else CPU).
      2. Load the tokenizer and checkpoint and build the model in eval mode.
      3. Decode ``--samples`` sequences in batches of at most ``--batch_size``,
         starting from simplex noise, with every position attended and none
         locked. Each sequence is written as one JSON line holding ``index``,
         ``text`` and ``ids``.
      4. Filter the generated texts (``drop_empty=True``), compute token
         diversity over all generated id sequences, score the kept texts with
         the ``--gen_ppl_model`` scorer, and print a JSON report to stdout.

    A relative ``--output`` is resolved against the checkpoint's directory,
    not the current working directory.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(checkpoint, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    decode_cfg = DecodeConfig(
        label=f"{args.rule}_t{args.endpoint_temp:g}_{args.final_from}",
        rule=args.rule,
        steps=args.steps,
        model_t_mode=args.model_t_mode,
        eta=args.eta,
        damping=args.damping,
        max_gamma=args.max_gamma,
        endpoint_temp=args.endpoint_temp,
        state_floor=args.state_floor,
        final_from=args.final_from,
        eos_logit_bias=args.eos_logit_bias,
    )

    # Relative output paths live next to the checkpoint.
    output = Path(args.output)
    if not output.is_absolute():
        output = Path(args.checkpoint).resolve().parent / output
    output.parent.mkdir(parents=True, exist_ok=True)

    generated_texts: list[str] = []
    generated_ids: list[list[int]] = []
    decoded_so_far = 0
    with output.open("w", encoding="utf-8") as sink:
        next_index = 0
        while decoded_so_far < args.samples:
            # The final batch may be smaller than --batch_size.
            cur_bs = min(args.batch_size, args.samples - decoded_so_far)
            grid = (cur_bs, args.max_len)

            init = sample_noise_simplex(
                grid,
                tokenizer.vocab_size,
                device,
                args.eps,
                noise_mode=args.noise_init,
                target_prob=args.target_prob,
                noise_sigma=args.noise_sigma,
                dirichlet_concentration=args.dirichlet_init_concentration,
            )
            # Unconditional generation: attend everywhere, lock nothing.
            attn = torch.ones(grid, dtype=torch.bool, device=device)
            lock = torch.zeros(grid, dtype=torch.bool, device=device)
            lock_probs = torch.zeros(
                (cur_bs, args.max_len, tokenizer.vocab_size),
                dtype=torch.float32,
                device=device,
            )

            decoded = decode_batch(model, init, attn, lock, lock_probs, decode_cfg, args.eps, tokenizer.eos_id)
            # Take the highest-scoring entry along the last axis at each position.
            for ids in decoded.argmax(dim=-1).detach().cpu().tolist():
                text = tokenizer.decode(ids, stop_at_eos=False, skip_special_tokens=False)
                generated_ids.append(ids)
                generated_texts.append(text)
                record = {"index": next_index, "text": text, "ids": ids}
                sink.write(json.dumps(record, ensure_ascii=False) + "\n")
                next_index += 1
            decoded_so_far += cur_bs

    kept_texts, _keep_mask = filter_generated_texts(generated_texts, drop_empty=True)
    diversity = summarize_token_diversity(generated_ids)
    gen_result = score_gen_ppl(
        kept_texts,
        model_name_or_path=args.gen_ppl_model,
        batch_size=args.gen_ppl_batch_size,
        max_length=args.gen_ppl_max_length,
        device=device,
        drop_remainder=False,
    )

    report = {
        "checkpoint": args.checkpoint,
        "output": str(output),
        "samples": args.samples,
        "kept_samples": len(kept_texts),
        "decode": vars(args),
        "gen_ppl": gen_result.ppl,
        "gen_nll_per_token": gen_result.nll_per_token,
        "gen_tokens": gen_result.tokens,
        "gen_scored_samples": gen_result.kept_samples - gen_result.skipped_samples,
        "sample_entropy": diversity.sample_entropy,
        "distinct_1": diversity.distinct_1,
        "distinct_2": diversity.distinct_2,
        "top_token_mass": diversity.top_token_mass,
        "preview": kept_texts[:5],
    }
    print(json.dumps(report, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
LTA_openwebtext_dualt/scripts/watch_infer_lm1b_classic_c1024_every1k_t1p45.sh ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Watch a training run directory and evaluate checkpoints as they appear.
#
# Every SLEEP_SECONDS the script lists ${RUN_DIR}/step_*.pt, keeps the
# checkpoints whose step number is a multiple of STEP_INTERVAL and that are not
# yet recorded in PROCESSED_FILE, and runs
# scripts/eval_owt_normal_steps_sweep_20260515.py on each of them. A checkpoint
# is appended to PROCESSED_FILE only after its evaluation command succeeded.
# The loop never terminates on its own; because of `set -euo pipefail`, a
# failing evaluation aborts the watcher.
#
# Environment:
#   RUN_DIR (required)      directory containing step_*.pt checkpoints
#   TOKENIZER_PATH, SCORER  tokenizer file and scorer model path
#   CUDA_VISIBLE_DEVICES    GPU(s) used for evaluation (default: 3)
#   N_SAMPLES, STEPS, CMAX, TEMP, MAX_LEN, DECODE_BATCH, SCORE_BATCH,
#   SCORE_MAX_LENGTH        forwarded to the evaluation script
#   STEP_INTERVAL           only steps divisible by this are evaluated
#   SLEEP_SECONDS           pause between directory scans
#   DATE_TAG, OUT_BASE, LOG_DIR, PROCESSED_FILE   output locations
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1

RUN_DIR="${RUN_DIR:?RUN_DIR is required, e.g. runs/<run_name>}"
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"

CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-3}"
N_SAMPLES="${N_SAMPLES:-256}"
STEPS="${STEPS:-128}"
CMAX="${CMAX:-1024}"
TEMP="${TEMP:-1.45}"
MAX_LEN="${MAX_LEN:-128}"
DECODE_BATCH="${DECODE_BATCH:-16}"
SCORE_BATCH="${SCORE_BATCH:-8}"
SCORE_MAX_LENGTH="${SCORE_MAX_LENGTH:-256}"
STEP_INTERVAL="${STEP_INTERVAL:-1000}"
SLEEP_SECONDS="${SLEEP_SECONDS:-30}"
DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"

RUN_STEM="$(basename "${RUN_DIR}")"
# "1.45" -> "1p45" so the temperature can be embedded in file names.
TEMP_TAG="${TEMP//./p}"
OUT_BASE="${OUT_BASE:-docs/lta_samples/metrics_${DATE_TAG}/lm1b_classic_dirichlet_len${MAX_LEN}_every1k_normal_steps_state_t${TEMP_TAG}_c${CMAX}_n${N_SAMPLES}/${RUN_STEM}}"
LOG_DIR="${LOG_DIR:-logs/lm1b_classic_dirichlet_every1k_infer_watch}"
PROCESSED_FILE="${PROCESSED_FILE:-${LOG_DIR}/processed_${RUN_STEM}_steps${STEPS}_c${CMAX}_t${TEMP_TAG}_n${N_SAMPLES}.txt}"

mkdir -p "${OUT_BASE}" "${LOG_DIR}"
touch "${PROCESSED_FILE}"

echo "[watch-classic-1k] run_dir=${RUN_DIR}"
echo "[watch-classic-1k] out_base=${OUT_BASE}"
echo "[watch-classic-1k] processed_file=${PROCESSED_FILE}"
echo "[watch-classic-1k] interval=${STEP_INTERVAL} decode=normal_steps_sweep steps=${STEPS} cmax=${CMAX} temp=${TEMP} final_from=state n=${N_SAMPLES}"

while true; do
  # nullglob only around the glob, so an empty run dir yields an empty array
  # instead of the literal pattern.
  shopt -s nullglob
  ckpts=("${RUN_DIR}"/step_*.pt)
  shopt -u nullglob

  if (( ${#ckpts[@]} == 0 )); then
    echo "[watch-classic-1k] $(date +%F_%T) no step_*.pt yet"
    sleep "${SLEEP_SECONDS}"
    continue
  fi

  # Natural (version) sort: step_2000 comes before step_10000 even when the
  # step numbers are not zero-padded. Sorting into an array, rather than
  # piping into `while read`, keeps the loop in the main shell and stops the
  # evaluation command from inheriting the checkpoint list as its stdin.
  mapfile -t ckpts < <(printf '%s\n' "${ckpts[@]}" | sort -V)

  for ckpt in "${ckpts[@]}"; do
    base="$(basename "${ckpt}")"
    step="${base#step_}"
    step="${step%.pt}"

    # Files such as step_latest.pt also match the glob; a non-numeric suffix
    # would make the arithmetic expansion below fail and, under `set -e`,
    # kill the whole watcher. Skip them.
    if [[ ! "${step}" =~ ^[0-9]+$ ]]; then
      continue
    fi

    # 10# forces base 10 so zero-padded steps are not parsed as octal.
    step_num="$((10#${step}))"
    if (( step_num % STEP_INTERVAL != 0 )); then
      continue
    fi
    if grep -Fxq -- "${ckpt}" "${PROCESSED_FILE}"; then
      continue
    fi

    out_dir="${OUT_BASE}/step_${step}"
    log_file="${LOG_DIR}/infer_${RUN_STEM}_step_${step}_t${TEMP_TAG}.log"
    mkdir -p "${out_dir}"

    echo "[watch-classic-1k] $(date +%F_%T) infer ${ckpt} -> ${out_dir}" | tee -a "${log_file}"
    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" python scripts/eval_owt_normal_steps_sweep_20260515.py \
      --checkpoint "${ckpt}" \
      --tokenizer_path "${TOKENIZER_PATH}" \
      --scorer "${SCORER}" \
      --out_dir "${out_dir}" \
      --steps_list "${STEPS}" \
      --cmax_list "${CMAX}" \
      --endpoint_temps "${TEMP}" \
      --n_samples "${N_SAMPLES}" \
      --max_len "${MAX_LEN}" \
      --decode_batch "${DECODE_BATCH}" \
      --score_batch "${SCORE_BATCH}" \
      --score_max_length "${SCORE_MAX_LENGTH}" \
      --detokenizer lm1b \
      --seed 20260511 \
      --save_samples 16 \
      2>&1 | tee -a "${log_file}"

    # Reached only when the pipeline above succeeded (set -e + pipefail).
    echo "${ckpt}" >> "${PROCESSED_FILE}"
    echo "[watch-classic-1k] $(date +%F_%T) done step_${step}" | tee -a "${log_file}"
  done

  sleep "${SLEEP_SECONDS}"
done