JinghuiLuAstronaut commited on
Commit
8ed0c93
·
verified ·
1 Parent(s): 9805aea

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. LTA_openwebtext_dualt/logs/lta_lm1b_classic_dirichlet_len512_gbs512_4gpu_10k_save1k_20260523.train.pid +1 -0
  2. LTA_openwebtext_dualt/logs/noise_geometry_combo_4gpu/20260517_170456.log +994 -0
  3. LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/driver.log +0 -0
  4. LTA_openwebtext_dualt/scripts/apple_to_apple_lta_checks.py +631 -0
  5. LTA_openwebtext_dualt/scripts/build_lta_owt_compact_gpt2bpe_stream1024_train_minus_100k_np8.sh +13 -0
  6. LTA_openwebtext_dualt/scripts/build_owt_t5_elf_dataset.py +587 -0
  7. LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_state_20260508.py +51 -0
  8. LTA_openwebtext_dualt/scripts/infer_lta_owt_t5_len128_uniform10k_then_lognsr_latest.sh +113 -0
  9. LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c1024_fullycoupled_8gpu_small_1m.sh +150 -0
  10. LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c16_dualt_4gpu_small_1m.sh +155 -0
  11. LTA_openwebtext_dualt/scripts/launch_lta_owt_c1024_fullycoupled_8gpu_len1024_gpt2_cached_chunks_1m.sh +60 -0
  12. LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v8192_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh +39 -0
  13. LTA_openwebtext_dualt/scripts/launch_lta_owt_elfaligned_t5_logitnormal_8gpu.sh +209 -0
  14. LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_outwd0p5_8gpu.sh +11 -0
  15. LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_grad_k1_rho025_subset10k_4gpu_100k.sh +148 -0
  16. LTA_openwebtext_dualt/scripts/run_lta_lm1b_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh +34 -0
  17. LTA_openwebtext_dualt/scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh +34 -0
  18. LTA_openwebtext_dualt/scripts/run_lta_owt_t5_absrope_adaln_dirichlet_len1024_Cv_to_2v_8gpu_mask0p1_1p0_sameT_1m_save10k.sh +36 -0
  19. LTA_openwebtext_dualt/scripts/run_train8_wrong_floor_pilots_4gpu.sh +194 -0
  20. LTA_openwebtext_dualt/scripts/watch_infer_owt_classic_fullvocab_len1024_lr2e4_gbs2048_latest_every1k_t1p45.sh +158 -0
LTA_openwebtext_dualt/logs/lta_lm1b_classic_dirichlet_len512_gbs512_4gpu_10k_save1k_20260523.train.pid ADDED
@@ -0,0 +1 @@
 
 
1
+ 993819
LTA_openwebtext_dualt/logs/noise_geometry_combo_4gpu/20260517_170456.log ADDED
@@ -0,0 +1,994 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [combo-pilot] start stamp=20260517_170456 len=256 vocab=969 out=docs/lta_samples/metrics_20260517/noise_geometry_combo_len256_bs512_ode128_20260517_170456
2
+ [combo-pilot] round=1 Sun May 17 17:04:56 UTC 2026
3
+ [combo-pilot] train config=logistic_unigram_shared_highC from=0 to=1000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.0
4
+ [combo-pilot] eval config=logistic_unigram_shared_highC step=1000
5
+ [eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_20260517_170456 step=1000 soft=none
6
+ [decode] max_len=256 generated=64/64
7
+ {
8
+ "num_rows": 1,
9
+ "best_by_run": {
10
+ "train8_combo_len256_logistic_unigram_shared_highC_20260517_170456::none": {
11
+ "run": "train8_combo_len256_logistic_unigram_shared_highC_20260517_170456",
12
+ "checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_20260517_170456/step_0001000.pt",
13
+ "ckpt_step": 1000,
14
+ "endpoint_softening": "none",
15
+ "decode_rule": "flowmap",
16
+ "steps": 128,
17
+ "time_schedule": "logit_normal",
18
+ "model_t_mode": "post",
19
+ "final_from": "state",
20
+ "n_gen": 64,
21
+ "n_refs": 8,
22
+ "token_acc_mean": 0.0487060546875,
23
+ "token_acc_min": 0.03515625,
24
+ "token_acc_max": 0.07421875,
25
+ "exact_acc": 0.0,
26
+ "exact_count": 0,
27
+ "exact_ref_coverage": 0.0,
28
+ "exact_ref_count": 0,
29
+ "exact_ref_hits": [],
30
+ "best_ref_idx": [
31
+ 5,
32
+ 0,
33
+ 0,
34
+ 0,
35
+ 5,
36
+ 5,
37
+ 5,
38
+ 0,
39
+ 5,
40
+ 2,
41
+ 1,
42
+ 0,
43
+ 7,
44
+ 2,
45
+ 7,
46
+ 0,
47
+ 3,
48
+ 3,
49
+ 2,
50
+ 0,
51
+ 2,
52
+ 2,
53
+ 5,
54
+ 7,
55
+ 5,
56
+ 7,
57
+ 7,
58
+ 2,
59
+ 5,
60
+ 7,
61
+ 5,
62
+ 2,
63
+ 1,
64
+ 5,
65
+ 0,
66
+ 0,
67
+ 5,
68
+ 2,
69
+ 0,
70
+ 0,
71
+ 2,
72
+ 0,
73
+ 0,
74
+ 5,
75
+ 5,
76
+ 3,
77
+ 5,
78
+ 5,
79
+ 5,
80
+ 3,
81
+ 3,
82
+ 0,
83
+ 3,
84
+ 2,
85
+ 5,
86
+ 0,
87
+ 7,
88
+ 0,
89
+ 1,
90
+ 5,
91
+ 2,
92
+ 7,
93
+ 3,
94
+ 2
95
+ ],
96
+ "best_token_acc": [
97
+ 0.04296875,
98
+ 0.04296875,
99
+ 0.04296875,
100
+ 0.046875,
101
+ 0.05859375,
102
+ 0.04296875,
103
+ 0.04296875,
104
+ 0.05859375,
105
+ 0.046875,
106
+ 0.05859375,
107
+ 0.04296875,
108
+ 0.05859375,
109
+ 0.0390625,
110
+ 0.046875,
111
+ 0.0625,
112
+ 0.0390625,
113
+ 0.04296875,
114
+ 0.046875,
115
+ 0.046875,
116
+ 0.046875,
117
+ 0.05078125,
118
+ 0.05078125,
119
+ 0.04296875,
120
+ 0.0546875,
121
+ 0.046875,
122
+ 0.046875,
123
+ 0.046875,
124
+ 0.046875,
125
+ 0.0625,
126
+ 0.0625,
127
+ 0.05078125,
128
+ 0.0390625,
129
+ 0.0546875,
130
+ 0.046875,
131
+ 0.04296875,
132
+ 0.0390625,
133
+ 0.05078125,
134
+ 0.0390625,
135
+ 0.046875,
136
+ 0.04296875,
137
+ 0.03515625,
138
+ 0.046875,
139
+ 0.046875,
140
+ 0.0546875,
141
+ 0.0546875,
142
+ 0.04296875,
143
+ 0.04296875,
144
+ 0.0546875,
145
+ 0.04296875,
146
+ 0.046875,
147
+ 0.05078125,
148
+ 0.07421875,
149
+ 0.04296875,
150
+ 0.05078125,
151
+ 0.046875,
152
+ 0.0546875,
153
+ 0.0546875,
154
+ 0.04296875,
155
+ 0.0546875,
156
+ 0.0546875,
157
+ 0.0546875,
158
+ 0.05078125,
159
+ 0.04296875,
160
+ 0.05078125
161
+ ]
162
+ }
163
+ },
164
+ "first_exact_by_run": {}
165
+ }
166
+ RESULT config=logistic_unigram_shared_highC ckpt_step=1000 views=512000 token_acc=0.0487 exact=0/64 exact_refs=0 hits=[]
167
+ [combo-pilot] continue config=logistic_unigram_shared_highC step=1000
168
+ [combo-pilot] train config=logistic_unigram_shared_highC_seqrand from=0 to=1000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.5
169
+ [combo-pilot] eval config=logistic_unigram_shared_highC_seqrand step=1000
170
+ [eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456 step=1000 soft=none
171
+ [decode] max_len=256 generated=64/64
172
+ {
173
+ "num_rows": 1,
174
+ "best_by_run": {
175
+ "train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456::none": {
176
+ "run": "train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456",
177
+ "checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456/step_0001000.pt",
178
+ "ckpt_step": 1000,
179
+ "endpoint_softening": "none",
180
+ "decode_rule": "flowmap",
181
+ "steps": 128,
182
+ "time_schedule": "logit_normal",
183
+ "model_t_mode": "post",
184
+ "final_from": "state",
185
+ "n_gen": 64,
186
+ "n_refs": 8,
187
+ "token_acc_mean": 0.04034423828125,
188
+ "token_acc_min": 0.0234375,
189
+ "token_acc_max": 0.0625,
190
+ "exact_acc": 0.0,
191
+ "exact_count": 0,
192
+ "exact_ref_coverage": 0.0,
193
+ "exact_ref_count": 0,
194
+ "exact_ref_hits": [],
195
+ "best_ref_idx": [
196
+ 0,
197
+ 0,
198
+ 0,
199
+ 0,
200
+ 0,
201
+ 3,
202
+ 0,
203
+ 7,
204
+ 0,
205
+ 4,
206
+ 0,
207
+ 5,
208
+ 4,
209
+ 0,
210
+ 0,
211
+ 0,
212
+ 3,
213
+ 0,
214
+ 0,
215
+ 0,
216
+ 3,
217
+ 0,
218
+ 0,
219
+ 0,
220
+ 4,
221
+ 0,
222
+ 0,
223
+ 5,
224
+ 0,
225
+ 4,
226
+ 0,
227
+ 0,
228
+ 0,
229
+ 0,
230
+ 5,
231
+ 0,
232
+ 0,
233
+ 0,
234
+ 0,
235
+ 4,
236
+ 0,
237
+ 0,
238
+ 0,
239
+ 5,
240
+ 3,
241
+ 0,
242
+ 0,
243
+ 0,
244
+ 0,
245
+ 4,
246
+ 0,
247
+ 4,
248
+ 0,
249
+ 0,
250
+ 0,
251
+ 0,
252
+ 5,
253
+ 0,
254
+ 0,
255
+ 0,
256
+ 4,
257
+ 0,
258
+ 3,
259
+ 0
260
+ ],
261
+ "best_token_acc": [
262
+ 0.03515625,
263
+ 0.03515625,
264
+ 0.03125,
265
+ 0.05859375,
266
+ 0.03515625,
267
+ 0.0234375,
268
+ 0.03515625,
269
+ 0.02734375,
270
+ 0.0625,
271
+ 0.03515625,
272
+ 0.02734375,
273
+ 0.03125,
274
+ 0.0234375,
275
+ 0.03515625,
276
+ 0.046875,
277
+ 0.04296875,
278
+ 0.05078125,
279
+ 0.03125,
280
+ 0.03515625,
281
+ 0.0625,
282
+ 0.03125,
283
+ 0.04296875,
284
+ 0.02734375,
285
+ 0.04296875,
286
+ 0.03125,
287
+ 0.0390625,
288
+ 0.05078125,
289
+ 0.0390625,
290
+ 0.02734375,
291
+ 0.03125,
292
+ 0.03125,
293
+ 0.0234375,
294
+ 0.046875,
295
+ 0.05078125,
296
+ 0.04296875,
297
+ 0.03515625,
298
+ 0.05078125,
299
+ 0.04296875,
300
+ 0.0390625,
301
+ 0.05078125,
302
+ 0.0390625,
303
+ 0.046875,
304
+ 0.0390625,
305
+ 0.0390625,
306
+ 0.02734375,
307
+ 0.05078125,
308
+ 0.05078125,
309
+ 0.046875,
310
+ 0.04296875,
311
+ 0.046875,
312
+ 0.05859375,
313
+ 0.05859375,
314
+ 0.04296875,
315
+ 0.05078125,
316
+ 0.05078125,
317
+ 0.046875,
318
+ 0.03125,
319
+ 0.04296875,
320
+ 0.0390625,
321
+ 0.05078125,
322
+ 0.03125,
323
+ 0.03125,
324
+ 0.03515625,
325
+ 0.0390625
326
+ ]
327
+ }
328
+ },
329
+ "first_exact_by_run": {}
330
+ }
331
+ RESULT config=logistic_unigram_shared_highC_seqrand ckpt_step=1000 views=512000 token_acc=0.0403 exact=0/64 exact_refs=0 hits=[]
332
+ [combo-pilot] continue config=logistic_unigram_shared_highC_seqrand step=1000
333
+ [combo-pilot] train config=logistic_unigram_shared_C1024 from=0 to=1000 sampler=logistic_normal_linear_mean C=1.0->1024 unigram_shared=0.5 seq=0.0
334
+ [combo-pilot] eval config=logistic_unigram_shared_C1024 step=1000
335
+ [eval-decode-acc] train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456 step=1000 soft=none
336
+ [decode] max_len=256 generated=64/64
337
+ {
338
+ "num_rows": 1,
339
+ "best_by_run": {
340
+ "train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456::none": {
341
+ "run": "train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456",
342
+ "checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_C1024_20260517_170456/step_0001000.pt",
343
+ "ckpt_step": 1000,
344
+ "endpoint_softening": "none",
345
+ "decode_rule": "flowmap",
346
+ "steps": 128,
347
+ "time_schedule": "logit_normal",
348
+ "model_t_mode": "post",
349
+ "final_from": "state",
350
+ "n_gen": 64,
351
+ "n_refs": 8,
352
+ "token_acc_mean": 0.0487060546875,
353
+ "token_acc_min": 0.03515625,
354
+ "token_acc_max": 0.07421875,
355
+ "exact_acc": 0.0,
356
+ "exact_count": 0,
357
+ "exact_ref_coverage": 0.0,
358
+ "exact_ref_count": 0,
359
+ "exact_ref_hits": [],
360
+ "best_ref_idx": [
361
+ 5,
362
+ 0,
363
+ 0,
364
+ 0,
365
+ 5,
366
+ 5,
367
+ 5,
368
+ 0,
369
+ 5,
370
+ 2,
371
+ 1,
372
+ 0,
373
+ 7,
374
+ 2,
375
+ 7,
376
+ 0,
377
+ 3,
378
+ 3,
379
+ 2,
380
+ 0,
381
+ 2,
382
+ 2,
383
+ 5,
384
+ 7,
385
+ 5,
386
+ 7,
387
+ 7,
388
+ 2,
389
+ 5,
390
+ 7,
391
+ 5,
392
+ 2,
393
+ 1,
394
+ 5,
395
+ 0,
396
+ 0,
397
+ 5,
398
+ 2,
399
+ 0,
400
+ 0,
401
+ 2,
402
+ 0,
403
+ 0,
404
+ 5,
405
+ 5,
406
+ 3,
407
+ 5,
408
+ 5,
409
+ 5,
410
+ 3,
411
+ 3,
412
+ 0,
413
+ 3,
414
+ 2,
415
+ 5,
416
+ 0,
417
+ 7,
418
+ 0,
419
+ 1,
420
+ 5,
421
+ 2,
422
+ 7,
423
+ 3,
424
+ 2
425
+ ],
426
+ "best_token_acc": [
427
+ 0.04296875,
428
+ 0.04296875,
429
+ 0.04296875,
430
+ 0.046875,
431
+ 0.05859375,
432
+ 0.04296875,
433
+ 0.04296875,
434
+ 0.05859375,
435
+ 0.046875,
436
+ 0.05859375,
437
+ 0.04296875,
438
+ 0.05859375,
439
+ 0.0390625,
440
+ 0.046875,
441
+ 0.0625,
442
+ 0.0390625,
443
+ 0.04296875,
444
+ 0.046875,
445
+ 0.046875,
446
+ 0.046875,
447
+ 0.05078125,
448
+ 0.05078125,
449
+ 0.04296875,
450
+ 0.0546875,
451
+ 0.046875,
452
+ 0.046875,
453
+ 0.046875,
454
+ 0.046875,
455
+ 0.0625,
456
+ 0.0625,
457
+ 0.05078125,
458
+ 0.0390625,
459
+ 0.0546875,
460
+ 0.046875,
461
+ 0.04296875,
462
+ 0.0390625,
463
+ 0.05078125,
464
+ 0.0390625,
465
+ 0.046875,
466
+ 0.04296875,
467
+ 0.03515625,
468
+ 0.046875,
469
+ 0.046875,
470
+ 0.0546875,
471
+ 0.0546875,
472
+ 0.04296875,
473
+ 0.04296875,
474
+ 0.0546875,
475
+ 0.04296875,
476
+ 0.046875,
477
+ 0.05078125,
478
+ 0.07421875,
479
+ 0.04296875,
480
+ 0.05078125,
481
+ 0.046875,
482
+ 0.0546875,
483
+ 0.0546875,
484
+ 0.04296875,
485
+ 0.0546875,
486
+ 0.0546875,
487
+ 0.0546875,
488
+ 0.05078125,
489
+ 0.04296875,
490
+ 0.05078125
491
+ ]
492
+ }
493
+ },
494
+ "first_exact_by_run": {}
495
+ }
496
+ RESULT config=logistic_unigram_shared_C1024 ckpt_step=1000 views=512000 token_acc=0.0487 exact=0/64 exact_refs=0 hits=[]
497
+ [combo-pilot] continue config=logistic_unigram_shared_C1024 step=1000
498
+ [combo-pilot] train config=dirichlet_unigram_shared_highC from=0 to=1000 sampler=dirichlet C=64->4096 unigram_shared=0.5 seq=0.0
499
+ [combo-pilot] eval config=dirichlet_unigram_shared_highC step=1000
500
+ [eval-decode-acc] train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456 step=1000 soft=none
501
+ [decode] max_len=256 generated=64/64
502
+ {
503
+ "num_rows": 1,
504
+ "best_by_run": {
505
+ "train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456::none": {
506
+ "run": "train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456",
507
+ "checkpoint": "runs/train8_combo_len256_dirichlet_unigram_shared_highC_20260517_170456/step_0001000.pt",
508
+ "ckpt_step": 1000,
509
+ "endpoint_softening": "none",
510
+ "decode_rule": "flowmap",
511
+ "steps": 128,
512
+ "time_schedule": "logit_normal",
513
+ "model_t_mode": "post",
514
+ "final_from": "state",
515
+ "n_gen": 64,
516
+ "n_refs": 8,
517
+ "token_acc_mean": 0.03857421875,
518
+ "token_acc_min": 0.02734375,
519
+ "token_acc_max": 0.05078125,
520
+ "exact_acc": 0.0,
521
+ "exact_count": 0,
522
+ "exact_ref_coverage": 0.0,
523
+ "exact_ref_count": 0,
524
+ "exact_ref_hits": [],
525
+ "best_ref_idx": [
526
+ 1,
527
+ 1,
528
+ 1,
529
+ 2,
530
+ 1,
531
+ 1,
532
+ 0,
533
+ 1,
534
+ 0,
535
+ 1,
536
+ 0,
537
+ 0,
538
+ 1,
539
+ 1,
540
+ 1,
541
+ 1,
542
+ 1,
543
+ 1,
544
+ 1,
545
+ 2,
546
+ 0,
547
+ 1,
548
+ 1,
549
+ 2,
550
+ 1,
551
+ 1,
552
+ 0,
553
+ 0,
554
+ 1,
555
+ 0,
556
+ 2,
557
+ 1,
558
+ 1,
559
+ 0,
560
+ 0,
561
+ 1,
562
+ 0,
563
+ 2,
564
+ 0,
565
+ 1,
566
+ 1,
567
+ 1,
568
+ 1,
569
+ 1,
570
+ 0,
571
+ 5,
572
+ 2,
573
+ 1,
574
+ 0,
575
+ 2,
576
+ 1,
577
+ 1,
578
+ 1,
579
+ 2,
580
+ 1,
581
+ 0,
582
+ 1,
583
+ 1,
584
+ 1,
585
+ 1,
586
+ 1,
587
+ 1,
588
+ 0,
589
+ 1
590
+ ],
591
+ "best_token_acc": [
592
+ 0.03125,
593
+ 0.04296875,
594
+ 0.046875,
595
+ 0.0390625,
596
+ 0.0390625,
597
+ 0.0390625,
598
+ 0.04296875,
599
+ 0.03515625,
600
+ 0.0390625,
601
+ 0.03515625,
602
+ 0.03125,
603
+ 0.02734375,
604
+ 0.03515625,
605
+ 0.03125,
606
+ 0.03515625,
607
+ 0.03515625,
608
+ 0.03515625,
609
+ 0.04296875,
610
+ 0.04296875,
611
+ 0.03125,
612
+ 0.02734375,
613
+ 0.03125,
614
+ 0.04296875,
615
+ 0.0390625,
616
+ 0.0390625,
617
+ 0.03515625,
618
+ 0.03515625,
619
+ 0.0390625,
620
+ 0.046875,
621
+ 0.03515625,
622
+ 0.05078125,
623
+ 0.0390625,
624
+ 0.046875,
625
+ 0.04296875,
626
+ 0.0390625,
627
+ 0.0390625,
628
+ 0.0390625,
629
+ 0.04296875,
630
+ 0.03125,
631
+ 0.046875,
632
+ 0.03515625,
633
+ 0.046875,
634
+ 0.046875,
635
+ 0.04296875,
636
+ 0.03125,
637
+ 0.03515625,
638
+ 0.03515625,
639
+ 0.0390625,
640
+ 0.03125,
641
+ 0.046875,
642
+ 0.0390625,
643
+ 0.05078125,
644
+ 0.0390625,
645
+ 0.02734375,
646
+ 0.02734375,
647
+ 0.0390625,
648
+ 0.05078125,
649
+ 0.03125,
650
+ 0.03515625,
651
+ 0.04296875,
652
+ 0.0390625,
653
+ 0.04296875,
654
+ 0.0390625,
655
+ 0.046875
656
+ ]
657
+ }
658
+ },
659
+ "first_exact_by_run": {}
660
+ }
661
+ RESULT config=dirichlet_unigram_shared_highC ckpt_step=1000 views=512000 token_acc=0.0386 exact=0/64 exact_refs=0 hits=[]
662
+ [combo-pilot] continue config=dirichlet_unigram_shared_highC step=1000
663
+ [combo-pilot] round=2 Sun May 17 17:08:26 UTC 2026
664
+ [combo-pilot] train config=logistic_unigram_shared_highC from=1000 to=2000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.0
665
+ [combo-pilot] eval config=logistic_unigram_shared_highC step=2000
666
+ [eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_20260517_170456 step=2000 soft=none
667
+ [decode] max_len=256 generated=64/64
668
+ {
669
+ "num_rows": 1,
670
+ "best_by_run": {
671
+ "train8_combo_len256_logistic_unigram_shared_highC_20260517_170456::none": {
672
+ "run": "train8_combo_len256_logistic_unigram_shared_highC_20260517_170456",
673
+ "checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_20260517_170456/step_0002000.pt",
674
+ "ckpt_step": 2000,
675
+ "endpoint_softening": "none",
676
+ "decode_rule": "flowmap",
677
+ "steps": 128,
678
+ "time_schedule": "logit_normal",
679
+ "model_t_mode": "post",
680
+ "final_from": "state",
681
+ "n_gen": 64,
682
+ "n_refs": 8,
683
+ "token_acc_mean": 0.03033447265625,
684
+ "token_acc_min": 0.015625,
685
+ "token_acc_max": 0.046875,
686
+ "exact_acc": 0.0,
687
+ "exact_count": 0,
688
+ "exact_ref_coverage": 0.0,
689
+ "exact_ref_count": 0,
690
+ "exact_ref_hits": [],
691
+ "best_ref_idx": [
692
+ 1,
693
+ 1,
694
+ 1,
695
+ 1,
696
+ 7,
697
+ 3,
698
+ 1,
699
+ 7,
700
+ 0,
701
+ 0,
702
+ 0,
703
+ 3,
704
+ 1,
705
+ 0,
706
+ 1,
707
+ 5,
708
+ 0,
709
+ 0,
710
+ 3,
711
+ 0,
712
+ 0,
713
+ 1,
714
+ 0,
715
+ 7,
716
+ 7,
717
+ 1,
718
+ 7,
719
+ 0,
720
+ 1,
721
+ 0,
722
+ 7,
723
+ 1,
724
+ 0,
725
+ 0,
726
+ 0,
727
+ 0,
728
+ 0,
729
+ 3,
730
+ 1,
731
+ 0,
732
+ 0,
733
+ 1,
734
+ 7,
735
+ 5,
736
+ 1,
737
+ 0,
738
+ 1,
739
+ 1,
740
+ 1,
741
+ 0,
742
+ 0,
743
+ 0,
744
+ 0,
745
+ 0,
746
+ 3,
747
+ 0,
748
+ 1,
749
+ 7,
750
+ 7,
751
+ 0,
752
+ 7,
753
+ 0,
754
+ 7,
755
+ 5
756
+ ],
757
+ "best_token_acc": [
758
+ 0.03125,
759
+ 0.02734375,
760
+ 0.02734375,
761
+ 0.03515625,
762
+ 0.046875,
763
+ 0.02734375,
764
+ 0.03125,
765
+ 0.04296875,
766
+ 0.04296875,
767
+ 0.02734375,
768
+ 0.046875,
769
+ 0.03515625,
770
+ 0.02734375,
771
+ 0.0234375,
772
+ 0.01953125,
773
+ 0.02734375,
774
+ 0.02734375,
775
+ 0.0390625,
776
+ 0.02734375,
777
+ 0.01953125,
778
+ 0.03125,
779
+ 0.03125,
780
+ 0.01953125,
781
+ 0.0390625,
782
+ 0.0234375,
783
+ 0.03125,
784
+ 0.02734375,
785
+ 0.02734375,
786
+ 0.03125,
787
+ 0.03125,
788
+ 0.03125,
789
+ 0.02734375,
790
+ 0.03125,
791
+ 0.03515625,
792
+ 0.03125,
793
+ 0.02734375,
794
+ 0.03515625,
795
+ 0.02734375,
796
+ 0.0234375,
797
+ 0.02734375,
798
+ 0.03125,
799
+ 0.03125,
800
+ 0.03515625,
801
+ 0.03515625,
802
+ 0.02734375,
803
+ 0.01953125,
804
+ 0.0234375,
805
+ 0.0234375,
806
+ 0.015625,
807
+ 0.046875,
808
+ 0.03125,
809
+ 0.02734375,
810
+ 0.03515625,
811
+ 0.0234375,
812
+ 0.03125,
813
+ 0.02734375,
814
+ 0.0234375,
815
+ 0.02734375,
816
+ 0.03125,
817
+ 0.03515625,
818
+ 0.03515625,
819
+ 0.03125,
820
+ 0.03125,
821
+ 0.0390625
822
+ ]
823
+ }
824
+ },
825
+ "first_exact_by_run": {}
826
+ }
827
+ RESULT config=logistic_unigram_shared_highC ckpt_step=2000 views=1024000 token_acc=0.0303 exact=0/64 exact_refs=0 hits=[]
828
+ [combo-pilot] continue config=logistic_unigram_shared_highC step=2000
829
+ [combo-pilot] train config=logistic_unigram_shared_highC_seqrand from=1000 to=2000 sampler=logistic_normal_linear_mean C=64->4096 unigram_shared=0.5 seq=0.5
830
+ [combo-pilot] eval config=logistic_unigram_shared_highC_seqrand step=2000
831
+ [eval-decode-acc] train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456 step=2000 soft=none
832
+ [decode] max_len=256 generated=64/64
833
+ {
834
+ "num_rows": 1,
835
+ "best_by_run": {
836
+ "train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456::none": {
837
+ "run": "train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456",
838
+ "checkpoint": "runs/train8_combo_len256_logistic_unigram_shared_highC_seqrand_20260517_170456/step_0002000.pt",
839
+ "ckpt_step": 2000,
840
+ "endpoint_softening": "none",
841
+ "decode_rule": "flowmap",
842
+ "steps": 128,
843
+ "time_schedule": "logit_normal",
844
+ "model_t_mode": "post",
845
+ "final_from": "state",
846
+ "n_gen": 64,
847
+ "n_refs": 8,
848
+ "token_acc_mean": 0.04046630859375,
849
+ "token_acc_min": 0.01953125,
850
+ "token_acc_max": 0.06640625,
851
+ "exact_acc": 0.0,
852
+ "exact_count": 0,
853
+ "exact_ref_coverage": 0.0,
854
+ "exact_ref_count": 0,
855
+ "exact_ref_hits": [],
856
+ "best_ref_idx": [
857
+ 0,
858
+ 7,
859
+ 0,
860
+ 0,
861
+ 0,
862
+ 0,
863
+ 7,
864
+ 0,
865
+ 7,
866
+ 7,
867
+ 7,
868
+ 0,
869
+ 7,
870
+ 7,
871
+ 7,
872
+ 7,
873
+ 7,
874
+ 7,
875
+ 7,
876
+ 0,
877
+ 0,
878
+ 7,
879
+ 0,
880
+ 0,
881
+ 0,
882
+ 7,
883
+ 0,
884
+ 7,
885
+ 0,
886
+ 0,
887
+ 0,
888
+ 0,
889
+ 7,
890
+ 1,
891
+ 0,
892
+ 7,
893
+ 0,
894
+ 0,
895
+ 5,
896
+ 0,
897
+ 0,
898
+ 7,
899
+ 0,
900
+ 0,
901
+ 0,
902
+ 7,
903
+ 5,
904
+ 0,
905
+ 5,
906
+ 2,
907
+ 0,
908
+ 0,
909
+ 0,
910
+ 7,
911
+ 0,
912
+ 7,
913
+ 1,
914
+ 0,
915
+ 0,
916
+ 0,
917
+ 7,
918
+ 2,
919
+ 0,
920
+ 0
921
+ ],
922
+ "best_token_acc": [
923
+ 0.01953125,
924
+ 0.03125,
925
+ 0.03515625,
926
+ 0.0546875,
927
+ 0.0390625,
928
+ 0.0546875,
929
+ 0.0234375,
930
+ 0.03125,
931
+ 0.046875,
932
+ 0.05078125,
933
+ 0.0390625,
934
+ 0.0234375,
935
+ 0.0390625,
936
+ 0.05859375,
937
+ 0.02734375,
938
+ 0.02734375,
939
+ 0.0546875,
940
+ 0.05078125,
941
+ 0.03515625,
942
+ 0.046875,
943
+ 0.05859375,
944
+ 0.02734375,
945
+ 0.046875,
946
+ 0.04296875,
947
+ 0.0546875,
948
+ 0.01953125,
949
+ 0.046875,
950
+ 0.03125,
951
+ 0.05078125,
952
+ 0.05859375,
953
+ 0.04296875,
954
+ 0.01953125,
955
+ 0.05078125,
956
+ 0.02734375,
957
+ 0.046875,
958
+ 0.03515625,
959
+ 0.03515625,
960
+ 0.05859375,
961
+ 0.03125,
962
+ 0.04296875,
963
+ 0.046875,
964
+ 0.05078125,
965
+ 0.04296875,
966
+ 0.0546875,
967
+ 0.02734375,
968
+ 0.02734375,
969
+ 0.02734375,
970
+ 0.046875,
971
+ 0.01953125,
972
+ 0.03515625,
973
+ 0.06640625,
974
+ 0.03515625,
975
+ 0.046875,
976
+ 0.046875,
977
+ 0.05078125,
978
+ 0.03125,
979
+ 0.03125,
980
+ 0.03125,
981
+ 0.03125,
982
+ 0.0546875,
983
+ 0.0546875,
984
+ 0.02734375,
985
+ 0.0546875,
986
+ 0.03125
987
+ ]
988
+ }
989
+ },
990
+ "first_exact_by_run": {}
991
+ }
992
+ RESULT config=logistic_unigram_shared_highC_seqrand ckpt_step=2000 views=1024000 token_acc=0.0405 exact=0/64 exact_refs=0 hits=[]
993
+ [combo-pilot] continue config=logistic_unigram_shared_highC_seqrand step=2000
994
+ [combo-pilot] train config=logistic_unigram_shared_C1024 from=1000 to=2000 sampler=logistic_normal_linear_mean C=1.0->1024 unigram_shared=0.5 seq=0.0
LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/driver.log ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/scripts/apple_to_apple_lta_checks.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import csv
6
+ import json
7
+ import math
8
+ import sys
9
+ from collections import Counter
10
+ from pathlib import Path
11
+ from typing import Any, Iterable
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import DataLoader
16
+
17
+
18
+ REPO_ROOT = Path(__file__).resolve().parents[1]
19
+ if str(REPO_ROOT) not in sys.path:
20
+ sys.path.insert(0, str(REPO_ROOT))
21
+
22
+ from eval import build_model_from_ckpt
23
+ from flowtext_lab.bridges import make_dirichlet_bridge_batch
24
+ from flowtext_lab.data import EosPadCollator, WrappedStreamingTextSequenceDataset, iter_text_records
25
+ from flowtext_lab.decode import sample_noise_simplex, state_for_model
26
+ from flowtext_lab.tokenization import BpeTextTokenizer
27
+ from train import TokenizedTextCollator, load_tokenized_hf_dataset
28
+
29
+
30
+ def token_piece(tok: BpeTextTokenizer, idx: int) -> str:
31
+ raw = getattr(tok, "tokenizer", None)
32
+ id_to_token = getattr(raw, "id_to_token", None)
33
+ if callable(id_to_token):
34
+ piece = id_to_token(int(idx))
35
+ if piece is not None:
36
+ return str(piece)
37
+ return tok.decode([int(idx)], stop_at_eos=False, skip_special_tokens=False)
38
+
39
+
40
+ def token_text(tok: BpeTextTokenizer, idx: int) -> str:
41
+ return tok.decode([int(idx)], stop_at_eos=False, skip_special_tokens=False)
42
+
43
+
44
+ def compact_piece(s: str) -> str:
45
+ return s.replace("\n", "\\n").replace("\t", "\\t")
46
+
47
+
48
+ def load_batch(
49
+ *,
50
+ data_path: str,
51
+ tokenizer: BpeTextTokenizer,
52
+ max_len: int,
53
+ batch_size: int,
54
+ mode: str,
55
+ text_column: str | None,
56
+ openwebtext_split: str,
57
+ wrap_mode: str,
58
+ max_records: int,
59
+ tokenized_pad_token: str,
60
+ ) -> dict[str, torch.Tensor]:
61
+ if mode == "tokenized_hf":
62
+ ds = load_tokenized_hf_dataset(data_path, max_records=max_records)
63
+ pad_id = tokenizer.pad_id if tokenized_pad_token == "pad" and tokenizer.pad_id is not None else tokenizer.eos_id
64
+ collate = TokenizedTextCollator(pad_id, max_len=max_len)
65
+ examples = [ds[i] for i in range(min(batch_size, len(ds)))]
66
+ return collate(examples)
67
+ if mode != "wrap":
68
+ raise ValueError(f"unknown data mode: {mode}")
69
+ ds = WrappedStreamingTextSequenceDataset(
70
+ data_path,
71
+ tokenizer,
72
+ max_len=max_len,
73
+ text_column=text_column,
74
+ openwebtext_split=openwebtext_split,
75
+ max_records_per_epoch=max_records,
76
+ wrap_mode=wrap_mode,
77
+ )
78
+ loader = DataLoader(ds, batch_size=batch_size, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=max_len))
79
+ return next(iter(loader))
80
+
81
+
82
+ def iter_record_lengths(
83
+ *,
84
+ data_path: str,
85
+ tokenizer: BpeTextTokenizer,
86
+ mode: str,
87
+ text_column: str | None,
88
+ openwebtext_split: str,
89
+ max_records: int,
90
+ ) -> Iterable[int]:
91
+ if mode == "tokenized_hf":
92
+ ds = load_tokenized_hf_dataset(data_path, max_records=max_records)
93
+ for ex in ds:
94
+ raw = ex["input_ids"]
95
+ if hasattr(raw, "tolist"):
96
+ raw = raw.tolist()
97
+ yield len(raw)
98
+ return
99
+ for i, text in enumerate(
100
+ iter_text_records(
101
+ data_path,
102
+ text_column=text_column,
103
+ openwebtext_split=openwebtext_split,
104
+ detokenizer="auto",
105
+ )
106
+ ):
107
+ if i >= max_records:
108
+ break
109
+ ids = tokenizer.encode(text, add_eos=False, add_special_tokens=False)
110
+ yield len(ids)
111
+
112
+
113
+ def rate_summary(values: list[float]) -> dict[str, float]:
114
+ if not values:
115
+ return {"mean": 0.0, "min": 0.0, "p50": 0.0, "p90": 0.0, "p99": 0.0, "max": 0.0}
116
+ vals = sorted(float(x) for x in values)
117
+ n = len(vals)
118
+
119
+ def q(p: float) -> float:
120
+ return vals[min(n - 1, max(0, int(round((n - 1) * p))))]
121
+
122
+ return {
123
+ "mean": float(sum(vals) / n),
124
+ "min": float(vals[0]),
125
+ "p50": float(q(0.5)),
126
+ "p90": float(q(0.9)),
127
+ "p99": float(q(0.99)),
128
+ "max": float(vals[-1]),
129
+ }
130
+
131
+
132
+ def distribution_entropy_from_counts(counts: Counter[int]) -> float:
133
+ total = sum(counts.values())
134
+ if total <= 0:
135
+ return 0.0
136
+ out = 0.0
137
+ for c in counts.values():
138
+ p = c / total
139
+ out -= p * math.log(max(p, 1e-12))
140
+ return float(out)
141
+
142
+
143
+ def token_feature_rates(ids: torch.Tensor, tok: BpeTextTokenizer) -> dict[str, float]:
144
+ flat = [int(x) for x in ids.reshape(-1).tolist()]
145
+ if not flat:
146
+ return {}
147
+ pieces = [token_piece(tok, x) for x in flat]
148
+ texts = [token_text(tok, x) for x in flat]
149
+ specials = {tok.eos_id, tok.bos_id, tok.unk_id}
150
+ if tok.pad_id is not None:
151
+ specials.add(tok.pad_id)
152
+ denom = len(flat)
153
+ normal = [i for i, x in enumerate(flat) if x not in specials]
154
+ normal_denom = max(len(normal), 1)
155
+ return {
156
+ "bert_hash_rate": sum(pieces[i].startswith("##") for i in normal) / normal_denom,
157
+ "spm_cont_rate": sum((not pieces[i].startswith("▁")) and (not pieces[i].startswith("<")) for i in normal) / normal_denom,
158
+ "single_char_rate": sum(len(texts[i].strip()) == 1 for i in normal) / normal_denom,
159
+ "digit_piece_rate": sum(any(ch.isdigit() for ch in pieces[i]) for i in normal) / normal_denom,
160
+ "url_piece_rate": sum(("http" in pieces[i].lower() or "www" in pieces[i].lower() or ".com" in pieces[i].lower()) for i in normal) / normal_denom,
161
+ "special_rate": sum(x in specials for x in flat) / denom,
162
+ }
163
+
164
+
165
+ def command_data(args: argparse.Namespace) -> None:
166
+ tok = BpeTextTokenizer.from_file(args.tokenizer_path)
167
+ batch = load_batch(
168
+ data_path=args.data_path,
169
+ tokenizer=tok,
170
+ max_len=args.max_len,
171
+ batch_size=args.n_sequences,
172
+ mode=args.data_mode,
173
+ text_column=args.text_column,
174
+ openwebtext_split=args.openwebtext_split,
175
+ wrap_mode=args.wrap_mode,
176
+ max_records=args.max_records,
177
+ tokenized_pad_token=args.tokenized_pad_token,
178
+ )
179
+ ids = batch["ids"]
180
+ attn = batch.get("attn_mask", torch.ones_like(ids, dtype=torch.bool))
181
+ valid_ids = ids[attn]
182
+ counts = Counter(int(x) for x in valid_ids.tolist())
183
+ top = [
184
+ {
185
+ "id": int(i),
186
+ "piece": compact_piece(token_piece(tok, int(i))),
187
+ "text": compact_piece(token_text(tok, int(i))),
188
+ "count": int(c),
189
+ "rate": float(c / max(valid_ids.numel(), 1)),
190
+ }
191
+ for i, c in counts.most_common(args.top_k)
192
+ ]
193
+ seq_lens = attn.long().sum(dim=1).tolist()
194
+ internal = ids[:, 1:-1] if ids.size(1) > 2 else ids[:, :0]
195
+ internal_attn = attn[:, 1:-1] if attn.size(1) > 2 else attn[:, :0]
196
+ eos_internal = ((internal == tok.eos_id) & internal_attn).long().sum(dim=1).tolist()
197
+ pad_internal = []
198
+ if tok.pad_id is not None:
199
+ pad_internal = ((internal == tok.pad_id) & internal_attn).long().sum(dim=1).tolist()
200
+ pos0 = Counter(int(x) for x in ids[:, 0].tolist())
201
+ last_valid = []
202
+ for row, mask in zip(ids, attn):
203
+ idx = int(mask.long().sum().item()) - 1
204
+ if idx >= 0:
205
+ last_valid.append(int(row[idx].item()))
206
+ last_counts = Counter(last_valid)
207
+ record_lengths = list(
208
+ iter_record_lengths(
209
+ data_path=args.data_path,
210
+ tokenizer=tok,
211
+ mode=args.data_mode,
212
+ text_column=args.text_column,
213
+ openwebtext_split=args.openwebtext_split,
214
+ max_records=args.max_records,
215
+ )
216
+ )
217
+ out = {
218
+ "name": args.name,
219
+ "data_path": args.data_path,
220
+ "data_mode": args.data_mode,
221
+ "tokenizer_path": args.tokenizer_path,
222
+ "vocab_size": tok.vocab_size,
223
+ "bos_id": tok.bos_id,
224
+ "bos_piece": token_piece(tok, tok.bos_id),
225
+ "eos_id": tok.eos_id,
226
+ "eos_piece": token_piece(tok, tok.eos_id),
227
+ "pad_id": tok.pad_id,
228
+ "n_sequences": int(ids.size(0)),
229
+ "max_len": args.max_len,
230
+ "sequence_len": rate_summary([float(x) for x in seq_lens]),
231
+ "record_token_len_no_special_no_eos": rate_summary([float(x) for x in record_lengths]),
232
+ "internal_eos_per_seq": rate_summary([float(x) for x in eos_internal]),
233
+ "internal_pad_per_seq": rate_summary([float(x) for x in pad_internal]) if pad_internal else None,
234
+ "pos0_top": [
235
+ {"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(ids.size(0), 1)}
236
+ for i, c in pos0.most_common(args.top_k)
237
+ ],
238
+ "last_valid_top": [
239
+ {"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(len(last_valid), 1)}
240
+ for i, c in last_counts.most_common(args.top_k)
241
+ ],
242
+ "unigram_entropy": distribution_entropy_from_counts(counts),
243
+ "token_feature_rates": token_feature_rates(valid_ids, tok),
244
+ "top_unigram": top,
245
+ }
246
+ Path(args.out_json).parent.mkdir(parents=True, exist_ok=True)
247
+ Path(args.out_json).write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
248
+ print(json.dumps(out, indent=2, ensure_ascii=False), flush=True)
249
+
250
+
251
+ def ckpt_arg(ckpt_args: dict[str, Any], key: str, default: Any) -> Any:
252
+ return ckpt_args.get(key, default)
253
+
254
+
255
+ def make_bridge_for_eval(
256
+ *,
257
+ ids: torch.Tensor,
258
+ attn: torch.Tensor,
259
+ ckpt_args: dict[str, Any],
260
+ vocab_size: int,
261
+ t_value: float,
262
+ force_mask_ratio: float | None,
263
+ eps: float,
264
+ ) -> Any:
265
+ return make_dirichlet_bridge_batch(
266
+ ids=ids,
267
+ attn_mask=attn,
268
+ vocab_size=vocab_size,
269
+ target_prob=float(ckpt_arg(ckpt_args, "target_prob", 1.0)),
270
+ min_t=float(ckpt_arg(ckpt_args, "min_t", 0.0)),
271
+ max_t=float(ckpt_arg(ckpt_args, "max_t", 1.0)),
272
+ min_mask_ratio=float(ckpt_arg(ckpt_args, "min_mask_ratio", 0.1)),
273
+ max_mask_ratio=float(ckpt_arg(ckpt_args, "max_mask_ratio", 1.0)),
274
+ wrong_token_replace_prob=ckpt_arg(ckpt_args, "wrong_token_replace_prob", "0.0"),
275
+ wrong_token_schedule=str(ckpt_arg(ckpt_args, "wrong_token_schedule", "constant")),
276
+ wrong_token_exp_k=float(ckpt_arg(ckpt_args, "wrong_token_exp_k", 1.0)),
277
+ dirichlet_concentration_min=float(ckpt_arg(ckpt_args, "dirichlet_concentration_min", 1.0)),
278
+ dirichlet_concentration_max=float(ckpt_arg(ckpt_args, "dirichlet_concentration_max", 1024.0)),
279
+ eps=eps,
280
+ state_format=str(ckpt_arg(ckpt_args, "state_format", ckpt_arg(ckpt_args, "input_format", "prob"))),
281
+ dirichlet_endpoint_mode=str(ckpt_arg(ckpt_args, "dirichlet_endpoint_mode", "bernoulli_wrong")),
282
+ dirichlet_semantic_t_mode=str(ckpt_arg(ckpt_args, "dirichlet_semantic_t_mode", "same")),
283
+ dirichlet_semantic_t_value=float(ckpt_arg(ckpt_args, "dirichlet_semantic_t_value", 0.0)),
284
+ dirichlet_semantic_t_curve=str(ckpt_arg(ckpt_args, "dirichlet_semantic_t_curve", "linear")),
285
+ dirichlet_semantic_t_power=float(ckpt_arg(ckpt_args, "dirichlet_semantic_t_power", 1.0)),
286
+ dirichlet_support_t_curve=str(ckpt_arg(ckpt_args, "dirichlet_support_t_curve", "linear")),
287
+ dirichlet_support_t_power=float(ckpt_arg(ckpt_args, "dirichlet_support_t_power", 1.0)),
288
+ endpoint_sequence_random_prob_alpha=float(ckpt_arg(ckpt_args, "endpoint_sequence_random_prob_alpha", 0.0)),
289
+ categorical_wrong_from_full_vocab=bool(ckpt_arg(ckpt_args, "categorical_wrong_from_full_vocab", False)),
290
+ categorical_wrong_from_batch_valid_tokens=bool(ckpt_arg(ckpt_args, "categorical_wrong_from_batch_valid_tokens", False)),
291
+ categorical_wrong_basin_token_ids=ckpt_arg(ckpt_args, "categorical_wrong_basin_token_ids", ""),
292
+ categorical_wrong_basin_prob=float(ckpt_arg(ckpt_args, "categorical_wrong_basin_prob", 0.0)),
293
+ categorical_wrong_unigram_prob=float(ckpt_arg(ckpt_args, "categorical_wrong_unigram_prob", 0.0)),
294
+ categorical_wrong_uniform_prob=float(ckpt_arg(ckpt_args, "categorical_wrong_uniform_prob", 0.0)),
295
+ categorical_wrong_prob_floor=float(ckpt_arg(ckpt_args, "categorical_wrong_prob_floor", 0.0)),
296
+ categorical_gold_prob_floor=float(ckpt_arg(ckpt_args, "categorical_gold_prob_floor", 0.0)),
297
+ categorical_gold_prob_ceil=float(ckpt_arg(ckpt_args, "categorical_gold_prob_ceil", 1.0)),
298
+ simplex_bridge_sampler=str(ckpt_arg(ckpt_args, "simplex_bridge_sampler", "dirichlet")),
299
+ logistic_normal_sigma_min=float(ckpt_arg(ckpt_args, "logistic_normal_sigma_min", 0.18)),
300
+ logistic_normal_sigma_max=float(ckpt_arg(ckpt_args, "logistic_normal_sigma_max", 2.2)),
301
+ logistic_normal_tau_min=float(ckpt_arg(ckpt_args, "logistic_normal_tau_min", 0.65)),
302
+ logistic_normal_tau_max=float(ckpt_arg(ckpt_args, "logistic_normal_tau_max", 1.15)),
303
+ force_t=t_value,
304
+ force_mask_ratio=force_mask_ratio,
305
+ mask_ratio_floor_schedule=str(ckpt_arg(ckpt_args, "mask_ratio_floor_schedule", "none")),
306
+ mask_mixture_original_prob=float(ckpt_arg(ckpt_args, "mask_mixture_original_prob", 0.0)),
307
+ mask_mixture_lowk_prob=float(ckpt_arg(ckpt_args, "mask_mixture_lowk_prob", 0.0)),
308
+ mask_mixture_lowcorrupt_prob=float(ckpt_arg(ckpt_args, "mask_mixture_lowcorrupt_prob", 0.0)),
309
+ mask_mixture_block_prob=float(ckpt_arg(ckpt_args, "mask_mixture_block_prob", 0.0)),
310
+ mask_mixture_all_prob=float(ckpt_arg(ckpt_args, "mask_mixture_all_prob", 0.0)),
311
+ mask_mixture_lowk_clean_tokens=ckpt_arg(ckpt_args, "mask_mixture_lowk_clean_tokens", "1,2,4,8,16,32,64"),
312
+ mask_mixture_lowcorrupt_tokens=ckpt_arg(ckpt_args, "mask_mixture_lowcorrupt_tokens", "1,2,4,8,16,32,64"),
313
+ mask_mixture_block_tokens=ckpt_arg(ckpt_args, "mask_mixture_block_tokens", "64,128"),
314
+ clean_state_mode=str(ckpt_arg(ckpt_args, "clean_state_mode", "onehot")),
315
+ return_dense_targets=False,
316
+ )
317
+
318
+
319
+ def masked_loss_acc(logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> dict[str, float]:
320
+ flat_mask = mask.reshape(-1)
321
+ if not bool(flat_mask.any().item()):
322
+ return {"nll": 0.0, "ppl": 1.0, "acc": 0.0, "tokens": 0}
323
+ flat_logits = logits.reshape(-1, logits.size(-1))[flat_mask]
324
+ flat_target = target.reshape(-1)[flat_mask]
325
+ loss = F.cross_entropy(flat_logits, flat_target, reduction="mean")
326
+ pred = flat_logits.argmax(dim=-1)
327
+ acc = (pred == flat_target).float().mean()
328
+ return {
329
+ "nll": float(loss.detach().cpu()),
330
+ "ppl": float(torch.exp(loss.clamp(max=50)).detach().cpu()),
331
+ "acc": float(acc.detach().cpu()),
332
+ "tokens": int(flat_mask.sum().detach().cpu()),
333
+ }
334
+
335
+
336
+ @torch.inference_mode()
337
+ def command_teacher(args: argparse.Namespace) -> None:
338
+ tok = BpeTextTokenizer.from_file(args.tokenizer_path)
339
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
340
+ ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
341
+ ckpt_args = dict(ckpt.get("args", {}))
342
+ model = build_model_from_ckpt(ckpt, tok.vocab_size, args.max_len, device).eval()
343
+ batch = load_batch(
344
+ data_path=args.data_path,
345
+ tokenizer=tok,
346
+ max_len=args.max_len,
347
+ batch_size=args.batch_size,
348
+ mode=args.data_mode,
349
+ text_column=args.text_column,
350
+ openwebtext_split=args.openwebtext_split,
351
+ wrap_mode=args.wrap_mode,
352
+ max_records=args.max_records,
353
+ tokenized_pad_token=args.tokenized_pad_token,
354
+ )
355
+ ids = batch["ids"].to(device)
356
+ attn = batch.get("attn_mask", torch.ones_like(ids, dtype=torch.bool)).to(device)
357
+ rows = []
358
+ for t_value in [float(x) for x in args.t_values.split(",") if x.strip()]:
359
+ torch.manual_seed(args.seed + int(round(t_value * 1000000)))
360
+ bridge = make_bridge_for_eval(
361
+ ids=ids,
362
+ attn=attn,
363
+ ckpt_args=ckpt_args,
364
+ vocab_size=tok.vocab_size,
365
+ t_value=t_value,
366
+ force_mask_ratio=args.force_mask_ratio,
367
+ eps=args.eps,
368
+ )
369
+ model_t = bridge.t
370
+ logits = model(state_for_model(model, bridge.state, args.eps), model_t, attn).float()
371
+ valid = attn
372
+ corrupt = bridge.corrupt_mask & attn
373
+ pos0_pred = logits[:, 0].argmax(dim=-1)
374
+ last_pred = []
375
+ for b in range(ids.size(0)):
376
+ last = int(attn[b].long().sum().item()) - 1
377
+ last_pred.append(int(logits[b, last].argmax().detach().cpu()) if last >= 0 else -1)
378
+ pos0_counts = Counter(int(x) for x in pos0_pred.detach().cpu().tolist())
379
+ last_counts = Counter(last_pred)
380
+ probs = F.softmax(logits, dim=-1)
381
+ rows.append(
382
+ {
383
+ "name": args.name,
384
+ "checkpoint": args.checkpoint,
385
+ "ckpt_step": int(ckpt.get("step", -1)),
386
+ "t": t_value,
387
+ "force_mask_ratio": args.force_mask_ratio,
388
+ "corrupt_frac": float(corrupt.float().mean().detach().cpu()),
389
+ "wrong_frac": float((bridge.wrong_mask & attn).float().sum().detach().cpu() / attn.float().sum().clamp_min(1).detach().cpu()),
390
+ "valid": masked_loss_acc(logits, ids, valid),
391
+ "corrupt": masked_loss_acc(logits, ids, corrupt),
392
+ "dist_entropy": float((-(probs.clamp_min(args.eps) * probs.clamp_min(args.eps).log()).sum(dim=-1)[valid]).mean().detach().cpu()),
393
+ "mean_maxp": float(probs.max(dim=-1).values[valid].mean().detach().cpu()),
394
+ "pos0_gold_id": int(ids[0, 0].detach().cpu()),
395
+ "pos0_gold_piece": token_piece(tok, int(ids[0, 0].detach().cpu())),
396
+ "pos0_top": [
397
+ {"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(ids.size(0), 1)}
398
+ for i, c in pos0_counts.most_common(5)
399
+ ],
400
+ "last_top": [
401
+ {"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(ids.size(0), 1)}
402
+ for i, c in last_counts.most_common(5)
403
+ ],
404
+ }
405
+ )
406
+ out = Path(args.out_json)
407
+ out.parent.mkdir(parents=True, exist_ok=True)
408
+ out.write_text(json.dumps(rows, indent=2, ensure_ascii=False), encoding="utf-8")
409
+ with out.with_suffix(".tsv").open("w", newline="", encoding="utf-8") as f:
410
+ fields = [
411
+ "name",
412
+ "ckpt_step",
413
+ "t",
414
+ "force_mask_ratio",
415
+ "corrupt_frac",
416
+ "wrong_frac",
417
+ "valid_nll",
418
+ "valid_acc",
419
+ "corrupt_nll",
420
+ "corrupt_acc",
421
+ "dist_entropy",
422
+ "mean_maxp",
423
+ "pos0_gold_piece",
424
+ "pos0_top",
425
+ "last_top",
426
+ ]
427
+ writer = csv.DictWriter(f, fieldnames=fields, delimiter="\t")
428
+ writer.writeheader()
429
+ for row in rows:
430
+ writer.writerow(
431
+ {
432
+ "name": row["name"],
433
+ "ckpt_step": row["ckpt_step"],
434
+ "t": row["t"],
435
+ "force_mask_ratio": row["force_mask_ratio"],
436
+ "corrupt_frac": row["corrupt_frac"],
437
+ "wrong_frac": row["wrong_frac"],
438
+ "valid_nll": row["valid"]["nll"],
439
+ "valid_acc": row["valid"]["acc"],
440
+ "corrupt_nll": row["corrupt"]["nll"],
441
+ "corrupt_acc": row["corrupt"]["acc"],
442
+ "dist_entropy": row["dist_entropy"],
443
+ "mean_maxp": row["mean_maxp"],
444
+ "pos0_gold_piece": row["pos0_gold_piece"],
445
+ "pos0_top": " | ".join(f"{x['piece']}:{x['rate']:.2f}" for x in row["pos0_top"]),
446
+ "last_top": " | ".join(f"{x['piece']}:{x['rate']:.2f}" for x in row["last_top"]),
447
+ }
448
+ )
449
+ for row in rows:
450
+ print(
451
+ f"{row['name']} step={row['ckpt_step']} t={row['t']:.4f} "
452
+ f"valid_nll={row['valid']['nll']:.3f} valid_acc={row['valid']['acc']:.3f} "
453
+ f"corrupt_nll={row['corrupt']['nll']:.3f} corrupt_acc={row['corrupt']['acc']:.3f} "
454
+ f"pos0={row['pos0_top'][0]['piece']}:{row['pos0_top'][0]['rate']:.2f}",
455
+ flush=True,
456
+ )
457
+
458
+
459
+ def filter_top_p(probs: torch.Tensor, top_p: float, eps: float) -> torch.Tensor:
460
+ if top_p >= 1.0:
461
+ return probs
462
+ sorted_vals, sorted_idx = torch.sort(probs, dim=-1, descending=True)
463
+ total = sorted_vals.sum(dim=-1, keepdim=True).clamp_min(eps)
464
+ remove = sorted_vals.cumsum(dim=-1) > top_p * total
465
+ remove[..., 0] = False
466
+ sorted_vals = sorted_vals.masked_fill(remove, 0.0)
467
+ out = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_vals)
468
+ return out / out.sum(dim=-1, keepdim=True).clamp_min(eps)
469
+
470
+
471
+ def distribution_metrics(probs: torch.Tensor, ids: torch.Tensor, tok: BpeTextTokenizer, prefix: str) -> dict[str, Any]:
472
+ p = probs.clamp_min(1e-12)
473
+ ent = float((-(p * p.log()).sum(dim=-1)).mean().detach().cpu())
474
+ maxp, arg = probs.max(dim=-1)
475
+ counts = Counter(int(x) for x in arg.reshape(-1).detach().cpu().tolist())
476
+ return {
477
+ f"{prefix}_entropy": ent,
478
+ f"{prefix}_mean_top_mass": float(maxp.mean().detach().cpu()),
479
+ f"{prefix}_argmax_token_entropy": distribution_entropy_from_counts(counts),
480
+ f"{prefix}_argmax_top": [
481
+ {"id": i, "piece": compact_piece(token_piece(tok, i)), "count": c, "rate": c / max(arg.numel(), 1)}
482
+ for i, c in counts.most_common(8)
483
+ ],
484
+ **{f"{prefix}_{k}": v for k, v in token_feature_rates(arg.detach().cpu(), tok).items()},
485
+ }
486
+
487
+
488
+ @torch.inference_mode()
489
+ def command_trace(args: argparse.Namespace) -> None:
490
+ tok = BpeTextTokenizer.from_file(args.tokenizer_path)
491
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
492
+ torch.manual_seed(args.seed)
493
+ ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
494
+ model = build_model_from_ckpt(ckpt, tok.vocab_size, args.max_len, device).eval()
495
+ eps = args.eps
496
+ bs = args.batch_size
497
+ probs = sample_noise_simplex(
498
+ (bs, args.max_len),
499
+ tok.vocab_size,
500
+ device,
501
+ eps,
502
+ noise_mode="dirichlet",
503
+ target_prob=1.0,
504
+ noise_sigma=-1.0,
505
+ dirichlet_concentration=args.concentration_min,
506
+ )
507
+ attn = torch.ones((bs, args.max_len), dtype=torch.bool, device=device)
508
+ log_cmin = math.log(args.concentration_min)
509
+ log_cmax = math.log(args.concentration_max)
510
+ out = Path(args.out_jsonl)
511
+ out.parent.mkdir(parents=True, exist_ok=True)
512
+ snapshot = set(int(x) for x in args.trace_steps.split(",") if x.strip())
513
+ last_endpoint = probs
514
+ with out.open("w", encoding="utf-8") as f:
515
+ for step in range(args.steps):
516
+ support_t = (step + 1) / max(args.steps, 1)
517
+ t = torch.full((bs,), support_t, dtype=torch.float32, device=device)
518
+ logits = model(state_for_model(model, probs, eps), t, attn).float()
519
+ endpoint = F.softmax(logits / args.endpoint_temp, dim=-1)
520
+ endpoint = filter_top_p(endpoint, args.endpoint_top_p, eps)
521
+ tau = args.gumbel_tau_start + support_t * (args.gumbel_tau_end - args.gumbel_tau_start)
522
+ uniform = torch.rand_like(endpoint).clamp_(eps, 1.0 - eps)
523
+ gumbel = -torch.log(-torch.log(uniform))
524
+ projected = F.softmax((endpoint.clamp_min(eps).log() + gumbel) / max(tau, eps), dim=-1)
525
+ last_endpoint = projected
526
+ mean = (1.0 - support_t) / tok.vocab_size + support_t * projected
527
+ mean = mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps)
528
+ conc = math.exp(log_cmin + support_t * (log_cmax - log_cmin))
529
+ alpha = (mean * conc).clamp_min(eps)
530
+ probs = torch._standard_gamma(alpha).clamp_min(eps)
531
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
532
+ step_num = step + 1
533
+ if step_num in snapshot or step_num == args.steps:
534
+ row = {
535
+ "name": args.name,
536
+ "ckpt_step": int(ckpt.get("step", -1)),
537
+ "step": step_num,
538
+ "support_t": support_t,
539
+ "tau": tau,
540
+ "concentration": conc,
541
+ }
542
+ row.update(distribution_metrics(endpoint, endpoint.argmax(dim=-1), tok, "a"))
543
+ row.update(distribution_metrics(projected, projected.argmax(dim=-1), tok, "e"))
544
+ row.update(distribution_metrics(probs, probs.argmax(dim=-1), tok, "p"))
545
+ for pos in [0, 1, args.max_len - 2, args.max_len - 1]:
546
+ a_id = int(endpoint[0, pos].argmax().detach().cpu())
547
+ e_id = int(projected[0, pos].argmax().detach().cpu())
548
+ p_id = int(probs[0, pos].argmax().detach().cpu())
549
+ row[f"pos{pos}_a"] = {"id": a_id, "piece": compact_piece(token_piece(tok, a_id)), "prob": float(endpoint[0, pos, a_id].detach().cpu())}
550
+ row[f"pos{pos}_e"] = {"id": e_id, "piece": compact_piece(token_piece(tok, e_id)), "prob": float(projected[0, pos, e_id].detach().cpu())}
551
+ row[f"pos{pos}_p"] = {"id": p_id, "piece": compact_piece(token_piece(tok, p_id)), "prob": float(probs[0, pos, p_id].detach().cpu())}
552
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
553
+ print(
554
+ f"{args.name} step={step_num} aH={row['a_entropy']:.2f} eH={row['e_entropy']:.2f} pH={row['p_entropy']:.2f} "
555
+ f"a_top={row['a_argmax_top'][0]['piece']}:{row['a_argmax_top'][0]['rate']:.2f} "
556
+ f"p_top={row['p_argmax_top'][0]['piece']}:{row['p_argmax_top'][0]['rate']:.2f}",
557
+ flush=True,
558
+ )
559
+ if args.final_out:
560
+ final_probs = 0.5 * probs + 0.5 * last_endpoint
561
+ ids = final_probs.argmax(dim=-1).detach().cpu().tolist()
562
+ Path(args.final_out).write_text("\n\n".join(tok.decode(row, stop_at_eos=False, skip_special_tokens=False) for row in ids), encoding="utf-8")
563
+
564
+
565
+ def main() -> None:
566
+ ap = argparse.ArgumentParser()
567
+ sub = ap.add_subparsers(dest="cmd", required=True)
568
+ data = sub.add_parser("data")
569
+ data.add_argument("--name", required=True)
570
+ data.add_argument("--data_path", required=True)
571
+ data.add_argument("--tokenizer_path", required=True)
572
+ data.add_argument("--out_json", required=True)
573
+ data.add_argument("--data_mode", choices=["wrap", "tokenized_hf"], default="wrap")
574
+ data.add_argument("--text_column", default=None)
575
+ data.add_argument("--openwebtext_split", default="all")
576
+ data.add_argument("--wrap_mode", default="stream")
577
+ data.add_argument("--tokenized_pad_token", default="pad")
578
+ data.add_argument("--max_len", type=int, default=1024)
579
+ data.add_argument("--n_sequences", type=int, default=2048)
580
+ data.add_argument("--max_records", type=int, default=20000)
581
+ data.add_argument("--top_k", type=int, default=24)
582
+ data.set_defaults(func=command_data)
583
+
584
+ teacher = sub.add_parser("teacher")
585
+ teacher.add_argument("--name", required=True)
586
+ teacher.add_argument("--checkpoint", required=True)
587
+ teacher.add_argument("--data_path", required=True)
588
+ teacher.add_argument("--tokenizer_path", required=True)
589
+ teacher.add_argument("--out_json", required=True)
590
+ teacher.add_argument("--data_mode", choices=["wrap", "tokenized_hf"], default="wrap")
591
+ teacher.add_argument("--text_column", default=None)
592
+ teacher.add_argument("--openwebtext_split", default="all")
593
+ teacher.add_argument("--wrap_mode", default="stream")
594
+ teacher.add_argument("--tokenized_pad_token", default="pad")
595
+ teacher.add_argument("--max_len", type=int, default=1024)
596
+ teacher.add_argument("--batch_size", type=int, default=8)
597
+ teacher.add_argument("--max_records", type=int, default=20000)
598
+ teacher.add_argument("--t_values", default="0.0,0.0078125,0.03125,0.125,0.5,1.0")
599
+ teacher.add_argument("--force_mask_ratio", type=float, default=None)
600
+ teacher.add_argument("--seed", type=int, default=20260525)
601
+ teacher.add_argument("--eps", type=float, default=1e-8)
602
+ teacher.add_argument("--cpu", action="store_true")
603
+ teacher.set_defaults(func=command_teacher)
604
+
605
+ trace = sub.add_parser("trace")
606
+ trace.add_argument("--name", required=True)
607
+ trace.add_argument("--checkpoint", required=True)
608
+ trace.add_argument("--tokenizer_path", required=True)
609
+ trace.add_argument("--out_jsonl", required=True)
610
+ trace.add_argument("--final_out", default="")
611
+ trace.add_argument("--max_len", type=int, default=1024)
612
+ trace.add_argument("--batch_size", type=int, default=2)
613
+ trace.add_argument("--steps", type=int, default=128)
614
+ trace.add_argument("--trace_steps", default="1,2,4,8,16,32,64,96,128")
615
+ trace.add_argument("--concentration_min", type=float, default=30522)
616
+ trace.add_argument("--concentration_max", type=float, default=61044)
617
+ trace.add_argument("--endpoint_temp", type=float, default=1.45)
618
+ trace.add_argument("--endpoint_top_p", type=float, default=0.95)
619
+ trace.add_argument("--gumbel_tau_start", type=float, default=1.0)
620
+ trace.add_argument("--gumbel_tau_end", type=float, default=0.2)
621
+ trace.add_argument("--seed", type=int, default=20260525)
622
+ trace.add_argument("--eps", type=float, default=1e-8)
623
+ trace.add_argument("--cpu", action="store_true")
624
+ trace.set_defaults(func=command_trace)
625
+
626
+ args = ap.parse_args()
627
+ args.func(args)
628
+
629
+
630
+ if __name__ == "__main__":
631
+ main()
LTA_openwebtext_dualt/scripts/build_lta_owt_compact_gpt2bpe_stream1024_train_minus_100k_np8.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export PACKING_MODE="${PACKING_MODE:-stream_chunks}"
7
+ export OUTPUT_SUFFIX="${OUTPUT_SUFFIX:-stream1024}"
8
+ export CACHE_SUFFIX="${CACHE_SUFFIX:-_stream1024}"
9
+ export LOG_DIR="${LOG_DIR:-logs/data_build_compact_gpt2bpe_stream1024}"
10
+ export VOCAB_SIZES="${VOCAB_SIZES:-2048,4096,8192}"
11
+ export NUM_PROC="${NUM_PROC:-8}"
12
+
13
+ exec bash scripts/build_lta_owt_compact_gpt2bpe_packed_train_minus_100k_np8.sh "$@"
LTA_openwebtext_dualt/scripts/build_owt_t5_elf_dataset.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import os
7
+ import shutil
8
+ from pathlib import Path
9
+ from typing import Iterator
10
+
11
+
12
+ def parse_args() -> argparse.Namespace:
13
+ p = argparse.ArgumentParser(
14
+ description=(
15
+ "Build an ELF-style OpenWebText T5 token dataset. By default each raw "
16
+ "record is tokenized with add_special_tokens=False, overlength records "
17
+ "are split into max_len chunks, and short records stay short. The "
18
+ "packed_records mode instead concatenates EOS-terminated records up to "
19
+ "max_len while preserving record boundaries. stream_chunks concatenates "
20
+ "the token stream and slices exact max_len chunks, so chunk boundaries "
21
+ "are defined by the selected tokenizer."
22
+ )
23
+ )
24
+ p.add_argument("--data_path", required=True)
25
+ p.add_argument("--output_dir", required=True)
26
+ p.add_argument("--tokenizer_path", required=True)
27
+ p.add_argument("--text_column", default="text")
28
+ p.add_argument("--txt_record_mode", choices=["auto", "line", "eot"], default="auto")
29
+ p.add_argument("--openwebtext_split", choices=["all", "train_minus_100k", "valid_last_100k"], default="all")
30
+ p.add_argument("--openwebtext_valid_records", type=int, default=100_000)
31
+ p.add_argument("--detokenizer", default="auto")
32
+ p.add_argument("--max_len", type=int, default=1024)
33
+ p.add_argument(
34
+ "--packing_mode",
35
+ choices=["record_chunks", "packed_records", "stream_chunks"],
36
+ default="record_chunks",
37
+ help=(
38
+ "record_chunks preserves the old behavior. packed_records appends EOS "
39
+ "per record and packs multiple records into near-max_len examples. "
40
+ "stream_chunks appends EOS per record, concatenates records, and emits "
41
+ "exact max_len chunks across record boundaries."
42
+ ),
43
+ )
44
+ p.add_argument("--max_records", type=int, default=0)
45
+ p.add_argument("--min_len", type=int, default=1)
46
+ p.add_argument("--add_eos", action="store_true", help="Append tokenizer EOS to each raw record before chunking.")
47
+ p.add_argument("--add_special_tokens", action="store_true", help="Let the tokenizer add model special tokens.")
48
+ p.add_argument("--cache_dir", default="")
49
+ p.add_argument("--max_shard_size", default="500MB")
50
+ p.add_argument("--num_proc", type=int, default=max(1, min(32, (os.cpu_count() or 8) // 2)))
51
+ p.add_argument("--tokenize_batch_size", type=int, default=1024)
52
+ p.add_argument(
53
+ "--merge_parts",
54
+ action="store_true",
55
+ help="After parallel part build, merge into one save_to_disk dataset. Slower but portable.",
56
+ )
57
+ p.add_argument("--keep_parts", action="store_true")
58
+ p.add_argument("--resume_parts", action="store_true", help="Keep completed part-* directories and build only missing parts.")
59
+ p.add_argument("--stats_only", action="store_true")
60
+ p.add_argument("--overwrite", action="store_true")
61
+ return p.parse_args()
62
+
63
+
64
+ def _iter_examples(
65
+ *,
66
+ data_path: str,
67
+ tokenizer_path: str,
68
+ text_column: str | None,
69
+ txt_record_mode: str,
70
+ openwebtext_split: str,
71
+ openwebtext_valid_records: int,
72
+ detokenizer: str | None,
73
+ max_len: int,
74
+ packing_mode: str,
75
+ max_records: int,
76
+ min_len: int,
77
+ add_eos: bool,
78
+ add_special_tokens: bool,
79
+ ) -> Iterator[dict]:
80
+ from flowtext_lab.data import iter_text_records
81
+ from flowtext_lab.tokenization import BpeTextTokenizer
82
+
83
+ tokenizer = BpeTextTokenizer.from_file(tokenizer_path)
84
+ seen_records = 0
85
+ pack: list[int] = []
86
+
87
+ def emit_ids(ids: list[int]) -> dict:
88
+ return {
89
+ "input_ids": [int(x) for x in ids],
90
+ "sequence_length": int(len(ids)),
91
+ }
92
+
93
+ def iter_record_chunks(ids: list[int]) -> Iterator[dict]:
94
+ for start in range(0, len(ids), max_len):
95
+ chunk = ids[start : start + max_len]
96
+ if len(chunk) >= min_len:
97
+ yield emit_ids(chunk)
98
+ if start + max_len >= len(ids):
99
+ break
100
+
101
+ def flush_pack() -> Iterator[dict]:
102
+ nonlocal pack
103
+ if len(pack) >= min_len:
104
+ yield emit_ids(pack)
105
+ pack = []
106
+
107
+ def append_stream(ids: list[int]) -> Iterator[dict]:
108
+ nonlocal pack
109
+ pack.extend(int(x) for x in ids)
110
+ while len(pack) >= max_len:
111
+ yield emit_ids(pack[:max_len])
112
+ pack = pack[max_len:]
113
+
114
+ for text in iter_text_records(
115
+ data_path,
116
+ text_column=text_column,
117
+ txt_record_mode=txt_record_mode,
118
+ openwebtext_split=openwebtext_split,
119
+ openwebtext_valid_records=openwebtext_valid_records,
120
+ detokenizer=detokenizer,
121
+ ):
122
+ if not text:
123
+ continue
124
+ ids = tokenizer.encode(text, add_eos=add_eos, add_special_tokens=add_special_tokens)
125
+ if not ids:
126
+ continue
127
+ if packing_mode == "record_chunks":
128
+ yield from iter_record_chunks(ids)
129
+ elif packing_mode == "packed_records":
130
+ if len(ids) > max_len:
131
+ yield from flush_pack()
132
+ yield from iter_record_chunks(ids)
133
+ else:
134
+ if pack and len(pack) + len(ids) > max_len:
135
+ yield from flush_pack()
136
+ pack.extend(int(x) for x in ids)
137
+ if len(pack) >= max_len:
138
+ yield from flush_pack()
139
+ else:
140
+ yield from append_stream(ids)
141
+ seen_records += 1
142
+ if max_records > 0 and seen_records >= max_records:
143
+ break
144
+ if packing_mode in ("packed_records", "stream_chunks"):
145
+ yield from flush_pack()
146
+
147
+
148
+ def _stats(args: argparse.Namespace) -> dict:
149
+ num_examples = 0
150
+ total_tokens = 0
151
+ min_len = None
152
+ max_len = 0
153
+ hist = {"lt128": 0, "128_255": 0, "256_511": 0, "512_1023": 0, "eq1024": 0}
154
+ for ex in _iter_examples(**_gen_kwargs(args)):
155
+ length = int(ex["sequence_length"])
156
+ num_examples += 1
157
+ total_tokens += length
158
+ min_len = length if min_len is None else min(min_len, length)
159
+ max_len = max(max_len, length)
160
+ if length < 128:
161
+ hist["lt128"] += 1
162
+ elif length < 256:
163
+ hist["128_255"] += 1
164
+ elif length < 512:
165
+ hist["256_511"] += 1
166
+ elif length < args.max_len:
167
+ hist["512_1023"] += 1
168
+ else:
169
+ hist["eq1024"] += 1
170
+ return {
171
+ "num_examples": int(num_examples),
172
+ "total_tokens": int(total_tokens),
173
+ "mean_length": float(total_tokens / num_examples) if num_examples else 0.0,
174
+ "min_length": int(min_len or 0),
175
+ "max_length": int(max_len),
176
+ "length_hist": hist,
177
+ }
178
+
179
+
180
+ def _gen_kwargs(args: argparse.Namespace) -> dict:
181
+ return {
182
+ "data_path": args.data_path,
183
+ "tokenizer_path": args.tokenizer_path,
184
+ "text_column": args.text_column,
185
+ "txt_record_mode": args.txt_record_mode,
186
+ "openwebtext_split": args.openwebtext_split,
187
+ "openwebtext_valid_records": args.openwebtext_valid_records,
188
+ "detokenizer": args.detokenizer,
189
+ "max_len": int(args.max_len),
190
+ "packing_mode": args.packing_mode,
191
+ "max_records": int(args.max_records),
192
+ "min_len": int(args.min_len),
193
+ "add_eos": bool(args.add_eos),
194
+ "add_special_tokens": bool(args.add_special_tokens),
195
+ }
196
+
197
+
198
+ def _make_limited_specs(args: argparse.Namespace) -> list[tuple[str, int, int | None]]:
199
+ from flowtext_lab.data import _make_file_specs
200
+
201
+ root = Path(args.data_path)
202
+ if root.is_dir():
203
+ files = sorted(
204
+ p for p in root.rglob("*")
205
+ if p.suffix.lower() in {".txt", ".jsonl", ".json", ".parquet"}
206
+ )
207
+ else:
208
+ files = [root]
209
+ specs = _make_file_specs(files, args.openwebtext_split, int(args.openwebtext_valid_records))
210
+ if args.max_records <= 0:
211
+ return [(str(p), int(a), None if b is None else int(b)) for p, a, b in specs]
212
+
213
+ limited = []
214
+ remaining = int(args.max_records)
215
+ for path, start, stop in specs:
216
+ if remaining <= 0:
217
+ break
218
+ if stop is None:
219
+ limited.append((str(path), int(start), None))
220
+ break
221
+ count = max(0, int(stop) - int(start))
222
+ take = min(count, remaining)
223
+ if take > 0:
224
+ limited.append((str(path), int(start), int(start) + take))
225
+ remaining -= take
226
+ return limited
227
+
228
+
229
+ def _iter_parquet_text_batches(
230
+ path: Path,
231
+ *,
232
+ text_column: str | None,
233
+ row_start: int,
234
+ row_stop: int | None,
235
+ batch_size: int,
236
+ ) -> Iterator[list[str]]:
237
+ import pyarrow.parquet as pq
238
+
239
+ pf = pq.ParquetFile(path)
240
+ col = text_column
241
+ if col is None:
242
+ names = set(pf.schema_arrow.names)
243
+ col = next((c for c in ("text", "content", "document", "article", "sentence") if c in names), None)
244
+ if col is None:
245
+ raise ValueError(f"Could not infer text column for {path}")
246
+
247
+ offset = 0
248
+ stop = pf.metadata.num_rows if row_stop is None else min(row_stop, pf.metadata.num_rows)
249
+ for batch in pf.iter_batches(columns=[col], batch_size=batch_size):
250
+ batch_start = offset
251
+ batch_stop = offset + batch.num_rows
252
+ offset = batch_stop
253
+ if batch_stop <= row_start:
254
+ continue
255
+ if batch_start >= stop:
256
+ break
257
+ local_start = max(0, row_start - batch_start)
258
+ local_stop = min(batch.num_rows, stop - batch_start)
259
+ values = batch.column(0).slice(local_start, local_stop - local_start).to_pylist()
260
+ texts = [str(value) for value in values if value is not None and str(value)]
261
+ if texts:
262
+ yield texts
263
+
264
+
265
+ def _iter_part_examples(
266
+ *,
267
+ spec: tuple[str, int, int | None],
268
+ tokenizer_path: str,
269
+ text_column: str | None,
270
+ detokenizer: str | None,
271
+ max_len: int,
272
+ packing_mode: str,
273
+ min_len: int,
274
+ add_eos: bool,
275
+ add_special_tokens: bool,
276
+ tokenize_batch_size: int,
277
+ ) -> Iterator[dict]:
278
+ from flowtext_lab.text_detokenization import detokenize_text, infer_detokenizer_name
279
+ from flowtext_lab.tokenization import BpeTextTokenizer
280
+
281
+ path = Path(spec[0])
282
+ row_start = int(spec[1])
283
+ row_stop = None if spec[2] is None else int(spec[2])
284
+ tokenizer = BpeTextTokenizer.from_file(tokenizer_path)
285
+ resolved_detok = infer_detokenizer_name(raw_path=str(path), explicit=detokenizer)
286
+ pack: list[int] = []
287
+
288
+ def emit_ids(ids: list[int]) -> dict:
289
+ return {
290
+ "input_ids": [int(x) for x in ids],
291
+ "sequence_length": int(len(ids)),
292
+ }
293
+
294
+ def iter_record_chunks(ids: list[int]) -> Iterator[dict]:
295
+ for start in range(0, len(ids), max_len):
296
+ chunk = ids[start : start + max_len]
297
+ if len(chunk) >= min_len:
298
+ yield emit_ids(chunk)
299
+ if start + max_len >= len(ids):
300
+ break
301
+
302
+ def flush_pack() -> Iterator[dict]:
303
+ nonlocal pack
304
+ if len(pack) >= min_len:
305
+ yield emit_ids(pack)
306
+ pack = []
307
+
308
+ def append_stream(ids: list[int]) -> Iterator[dict]:
309
+ nonlocal pack
310
+ pack.extend(int(x) for x in ids)
311
+ while len(pack) >= max_len:
312
+ yield emit_ids(pack[:max_len])
313
+ pack = pack[max_len:]
314
+
315
+ for texts in _iter_parquet_text_batches(
316
+ path,
317
+ text_column=text_column,
318
+ row_start=row_start,
319
+ row_stop=row_stop,
320
+ batch_size=max(1, int(tokenize_batch_size)),
321
+ ):
322
+ if resolved_detok:
323
+ texts = [detokenize_text(text, resolved_detok) for text in texts]
324
+ encoded = tokenizer.tokenizer.encode_batch(texts, add_special_tokens=add_special_tokens)
325
+ for enc in encoded:
326
+ ids = list(enc.ids)
327
+ if add_eos:
328
+ ids.append(tokenizer.eos_id)
329
+ if not ids:
330
+ continue
331
+ if packing_mode == "record_chunks":
332
+ yield from iter_record_chunks(ids)
333
+ elif packing_mode == "packed_records":
334
+ if len(ids) > max_len:
335
+ yield from flush_pack()
336
+ yield from iter_record_chunks(ids)
337
+ else:
338
+ if pack and len(pack) + len(ids) > max_len:
339
+ yield from flush_pack()
340
+ pack.extend(int(x) for x in ids)
341
+ if len(pack) >= max_len:
342
+ yield from flush_pack()
343
+ else:
344
+ yield from append_stream(ids)
345
+ if packing_mode in ("packed_records", "stream_chunks"):
346
+ yield from flush_pack()
347
+
348
+
349
+ def _build_part(task: dict) -> dict:
350
+ from datasets import Dataset, Features, Sequence, Value, disable_progress_bars
351
+
352
+ disable_progress_bars()
353
+
354
+ part_dir = Path(task["part_dir"])
355
+ if part_dir.exists():
356
+ shutil.rmtree(part_dir)
357
+ features = Features(
358
+ {
359
+ "input_ids": Sequence(Value("int32")),
360
+ "sequence_length": Value("int64"),
361
+ }
362
+ )
363
+ ds = Dataset.from_generator(
364
+ _iter_part_examples,
365
+ gen_kwargs={
366
+ "spec": task["spec"],
367
+ "tokenizer_path": task["tokenizer_path"],
368
+ "text_column": task["text_column"],
369
+ "detokenizer": task["detokenizer"],
370
+ "max_len": task["max_len"],
371
+ "packing_mode": task["packing_mode"],
372
+ "min_len": task["min_len"],
373
+ "add_eos": task["add_eos"],
374
+ "add_special_tokens": task["add_special_tokens"],
375
+ "tokenize_batch_size": task["tokenize_batch_size"],
376
+ },
377
+ features=features,
378
+ cache_dir=task["cache_dir"] or None,
379
+ )
380
+ ds.save_to_disk(str(part_dir), max_shard_size=task["max_shard_size"])
381
+ lengths = ds["sequence_length"] if len(ds) else []
382
+ total_tokens = int(sum(int(x) for x in lengths))
383
+ if task["cache_dir"]:
384
+ shutil.rmtree(task["cache_dir"], ignore_errors=True)
385
+ return {
386
+ "part_dir": str(part_dir),
387
+ "num_examples": int(len(ds)),
388
+ "total_tokens": total_tokens,
389
+ "spec": task["spec"],
390
+ }
391
+
392
+
393
+ def _part_is_complete(part_dir: Path) -> bool:
394
+ return (part_dir / "state.json").exists() and any(part_dir.glob("data-*.arrow"))
395
+
396
+
397
+ def _summarize_part(part_dir: Path, spec: tuple[str, int, int | None]) -> dict:
398
+ from datasets import load_from_disk
399
+
400
+ ds = load_from_disk(str(part_dir))
401
+ lengths = ds["sequence_length"] if len(ds) else []
402
+ total_tokens = int(sum(int(x) for x in lengths))
403
+ return {
404
+ "part_dir": str(part_dir),
405
+ "num_examples": int(len(ds)),
406
+ "total_tokens": total_tokens,
407
+ "spec": spec,
408
+ }
409
+
410
+
411
+ def _preload_datasets_for_fork() -> None:
412
+ # Importing datasets pulls in fsspec, which scans Python entry points.
413
+ # On this machine that scan can intermittently hit a corrupt/fragile zipped
414
+ # egg when many workers import at once. Preloading in the parent lets forked
415
+ # workers reuse sys.modules instead of racing through the entry point scan.
416
+ from datasets import Dataset, Features, Sequence, Value, disable_progress_bars, load_from_disk # noqa: F401
417
+
418
+ disable_progress_bars()
419
+
420
+
421
+ def _parallel_build(args: argparse.Namespace) -> dict:
422
+ from concurrent.futures import ProcessPoolExecutor, as_completed
423
+
424
+ specs = _make_limited_specs(args)
425
+ if not specs:
426
+ raise RuntimeError("No input file specs found")
427
+
428
+ output_dir = Path(args.output_dir)
429
+ parts_root = output_dir / "parts"
430
+ parts_root.mkdir(parents=True, exist_ok=True)
431
+
432
+ tasks = []
433
+ part_results = []
434
+ for idx, spec in enumerate(specs):
435
+ part_dir = parts_root / f"part-{idx:05d}"
436
+ if args.resume_parts and _part_is_complete(part_dir):
437
+ part_results.append(_summarize_part(part_dir, spec))
438
+ continue
439
+ tasks.append(
440
+ {
441
+ "part_dir": str(part_dir),
442
+ "spec": spec,
443
+ "tokenizer_path": args.tokenizer_path,
444
+ "text_column": args.text_column,
445
+ "detokenizer": args.detokenizer,
446
+ "max_len": int(args.max_len),
447
+ "packing_mode": args.packing_mode,
448
+ "min_len": int(args.min_len),
449
+ "add_eos": bool(args.add_eos),
450
+ "add_special_tokens": bool(args.add_special_tokens),
451
+ "tokenize_batch_size": int(args.tokenize_batch_size),
452
+ "cache_dir": str(Path(args.cache_dir) / f"part-{idx:05d}") if args.cache_dir else "",
453
+ "max_shard_size": args.max_shard_size,
454
+ }
455
+ )
456
+
457
+ print(
458
+ f"[build] specs={len(specs)} existing={len(part_results)} "
459
+ f"todo={len(tasks)} num_proc={args.num_proc} output={output_dir}",
460
+ flush=True,
461
+ )
462
+ if tasks:
463
+ _preload_datasets_for_fork()
464
+ with ProcessPoolExecutor(max_workers=max(1, int(args.num_proc))) as pool:
465
+ futures = [pool.submit(_build_part, task) for task in tasks]
466
+ for done, fut in enumerate(as_completed(futures), start=1):
467
+ result = fut.result()
468
+ part_results.append(result)
469
+ print(
470
+ "[build] "
471
+ f"{done}/{len(futures)} {Path(result['part_dir']).name} "
472
+ f"examples={result['num_examples']} tokens={result['total_tokens']}",
473
+ flush=True,
474
+ )
475
+
476
+ part_results.sort(key=lambda x: x["part_dir"])
477
+ total_examples = sum(int(x["num_examples"]) for x in part_results)
478
+ total_tokens = sum(int(x["total_tokens"]) for x in part_results)
479
+ meta = {
480
+ "builder": "build_owt_t5_elf_dataset.py",
481
+ "format": f"elf_unconditional_tokenized_{args.packing_mode}_multipart",
482
+ "data_path": args.data_path,
483
+ "tokenizer_path": args.tokenizer_path,
484
+ "text_column": args.text_column,
485
+ "openwebtext_split": args.openwebtext_split,
486
+ "openwebtext_valid_records": args.openwebtext_valid_records,
487
+ "max_len": args.max_len,
488
+ "packing_mode": args.packing_mode,
489
+ "max_records": args.max_records,
490
+ "min_len": args.min_len,
491
+ "add_eos": args.add_eos,
492
+ "add_special_tokens": args.add_special_tokens,
493
+ "num_parts": len(part_results),
494
+ "num_examples": int(total_examples),
495
+ "total_tokens": int(total_tokens),
496
+ "mean_length": float(total_tokens / total_examples) if total_examples else 0.0,
497
+ "parts": part_results,
498
+ }
499
+ (output_dir / "elf_multi_part_meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True), encoding="utf-8")
500
+
501
+ if args.merge_parts:
502
+ from datasets import concatenate_datasets, load_from_disk
503
+
504
+ merged_tmp = output_dir / "_merged_tmp"
505
+ if merged_tmp.exists():
506
+ shutil.rmtree(merged_tmp)
507
+ datasets = [load_from_disk(result["part_dir"]) for result in part_results if result["num_examples"] > 0]
508
+ merged = datasets[0] if len(datasets) == 1 else concatenate_datasets(datasets)
509
+ merged.save_to_disk(str(merged_tmp), max_shard_size=args.max_shard_size)
510
+ for child in list(output_dir.iterdir()):
511
+ if child.name in {"_merged_tmp", "parts"}:
512
+ continue
513
+ if child.is_dir():
514
+ shutil.rmtree(child)
515
+ else:
516
+ child.unlink()
517
+ for child in list(merged_tmp.iterdir()):
518
+ child.rename(output_dir / child.name)
519
+ merged_tmp.rmdir()
520
+ if not args.keep_parts:
521
+ shutil.rmtree(parts_root)
522
+ meta["format"] = f"elf_unconditional_tokenized_{args.packing_mode}"
523
+ (output_dir / "elf_build_meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True), encoding="utf-8")
524
+
525
+ return meta
526
+
527
+
528
+ def main() -> None:
529
+ args = parse_args()
530
+ output_dir = Path(args.output_dir)
531
+
532
+ if args.stats_only:
533
+ print(json.dumps(_stats(args), indent=2, sort_keys=True))
534
+ return
535
+
536
+ if output_dir.exists():
537
+ if not args.overwrite:
538
+ if not args.resume_parts:
539
+ raise SystemExit(f"output_dir exists: {output_dir}; pass --overwrite to replace it")
540
+ elif not args.resume_parts:
541
+ shutil.rmtree(output_dir)
542
+ output_dir.mkdir(parents=True, exist_ok=True)
543
+
544
+ if args.num_proc > 1:
545
+ meta = _parallel_build(args)
546
+ print(json.dumps({k: v for k, v in meta.items() if k != "parts"}, indent=2, sort_keys=True))
547
+ return
548
+
549
+ from datasets import Dataset, Features, Sequence, Value
550
+
551
+ features = Features(
552
+ {
553
+ "input_ids": Sequence(Value("int32")),
554
+ "sequence_length": Value("int64"),
555
+ }
556
+ )
557
+ ds = Dataset.from_generator(
558
+ _iter_examples,
559
+ gen_kwargs=_gen_kwargs(args),
560
+ features=features,
561
+ cache_dir=args.cache_dir or None,
562
+ )
563
+ ds.save_to_disk(str(output_dir), max_shard_size=args.max_shard_size)
564
+
565
+ meta = {
566
+ "builder": "build_owt_t5_elf_dataset.py",
567
+ "format": f"elf_unconditional_tokenized_{args.packing_mode}",
568
+ "data_path": args.data_path,
569
+ "tokenizer_path": args.tokenizer_path,
570
+ "text_column": args.text_column,
571
+ "openwebtext_split": args.openwebtext_split,
572
+ "openwebtext_valid_records": args.openwebtext_valid_records,
573
+ "max_len": args.max_len,
574
+ "packing_mode": args.packing_mode,
575
+ "max_records": args.max_records,
576
+ "min_len": args.min_len,
577
+ "add_eos": args.add_eos,
578
+ "add_special_tokens": args.add_special_tokens,
579
+ "num_examples": int(len(ds)),
580
+ "columns": list(ds.column_names),
581
+ }
582
+ (output_dir / "elf_build_meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True), encoding="utf-8")
583
+ print(json.dumps(meta, indent=2, sort_keys=True))
584
+
585
+
586
+ if __name__ == "__main__":
587
+ main()
LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_state_20260508.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import importlib.util
5
+ import sys
6
+ from pathlib import Path
7
+
8
+
9
+ BASE = Path(__file__).with_name("eval_c1024_decode_sweep_20260507.py")
10
+ spec = importlib.util.spec_from_file_location("eval_c1024_decode_sweep_20260507", BASE)
11
+ if spec is None or spec.loader is None:
12
+ raise RuntimeError(f"cannot import {BASE}")
13
+ base = importlib.util.module_from_spec(spec)
14
+ sys.modules[spec.name] = base
15
+ spec.loader.exec_module(base)
16
+
17
+
18
+ def key_configs() -> list[base.DecodeConfig]:
19
+ return [
20
+ base.DecodeConfig(
21
+ "match_post_sem1_state_c16_t1p3",
22
+ "post",
23
+ 1.0,
24
+ 1.0,
25
+ "state",
26
+ endpoint_temp=1.3,
27
+ concentration_max=16.0,
28
+ ),
29
+ base.DecodeConfig(
30
+ "match_post_sem1_state_c64_t1p3",
31
+ "post",
32
+ 1.0,
33
+ 1.0,
34
+ "state",
35
+ endpoint_temp=1.3,
36
+ concentration_max=64.0,
37
+ ),
38
+ base.DecodeConfig(
39
+ "match_post_sem1_state_c1024_t1p3",
40
+ "post",
41
+ 1.0,
42
+ 1.0,
43
+ "state",
44
+ endpoint_temp=1.3,
45
+ concentration_max=1024.0,
46
+ ),
47
+ ]
48
+
49
+
50
+ base.default_configs = key_configs
51
+ base.main()
LTA_openwebtext_dualt/scripts/infer_lta_owt_t5_len128_uniform10k_then_lognsr_latest.sh ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
7
+ export PYTHONUNBUFFERED=1
8
+ export TOKENIZERS_PARALLELISM=false
9
+
10
+ RUN_PREFIX="${RUN_PREFIX:-lta_owt_t5_len128_uniform10k_then_lognsr}"
11
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
12
+ SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"
13
+
14
+ N_SAMPLES="${N_SAMPLES:-8}"
15
+ DECODE_BATCH="${DECODE_BATCH:-4}"
16
+ SCORE_BATCH="${SCORE_BATCH:-4}"
17
+ MAX_LEN="${MAX_LEN:-128}"
18
+ STEPS="${STEPS:-1024}"
19
+ ENDPOINT_TEMPS="${ENDPOINT_TEMPS:-1.0,1.15,1.30,1.45}"
20
+
21
+ DECODE_RULE="${DECODE_RULE:-dirichlet_resample}"
22
+ MODEL_T_MODE="${MODEL_T_MODE:-post}"
23
+ TIME_SCHEDULE="${TIME_SCHEDULE:-lognsr_gumbel}"
24
+ TIME_GUMBEL_LOC="${TIME_GUMBEL_LOC:-2.2}"
25
+ TIME_GUMBEL_SCALE="${TIME_GUMBEL_SCALE:-0.8}"
26
+ CONCENTRATION_MIN="${CONCENTRATION_MIN:-1}"
27
+ CONCENTRATION_MAX="${CONCENTRATION_MAX:-64}"
28
+ NOISE_INIT="${NOISE_INIT:-dirichlet}"
29
+ FINAL_FROM="${FINAL_FROM:-state}"
30
+ FINAL_SAMPLE_MODE="${FINAL_SAMPLE_MODE:-argmax}"
31
+
32
+ pick_run() {
33
+ local suffix="$1"
34
+ find runs -maxdepth 1 -type d -name "${RUN_PREFIX}*${suffix}" -printf "%T@ %p\n" 2>/dev/null \
35
+ | sort -nr \
36
+ | head -n 1 \
37
+ | cut -d' ' -f2-
38
+ }
39
+
40
+ RUN_DIR="${RUN_DIR:-}"
41
+ if [[ -z "${RUN_DIR}" ]]; then
42
+ RUN_DIR="$(pick_run "_resume_lognsr_sde_rollin")"
43
+ fi
44
+ if [[ -z "${RUN_DIR}" ]]; then
45
+ RUN_DIR="$(pick_run "_warmup_uniform_norollin")"
46
+ fi
47
+ if [[ -z "${RUN_DIR}" || ! -d "${RUN_DIR}" ]]; then
48
+ echo "[infer] could not find run dir for prefix=${RUN_PREFIX}" >&2
49
+ exit 1
50
+ fi
51
+
52
+ CKPT="${CKPT:-}"
53
+ if [[ -z "${CKPT}" ]]; then
54
+ CKPT="$(ls -1 "${RUN_DIR}"/step_*.pt 2>/dev/null | sort | tail -n 1 || true)"
55
+ fi
56
+ if [[ -z "${CKPT}" || ! -f "${CKPT}" ]]; then
57
+ echo "[infer] could not find checkpoint under ${RUN_DIR}" >&2
58
+ exit 1
59
+ fi
60
+
61
+ RUN_BASENAME="$(basename "${RUN_DIR}")"
62
+ CKPT_BASENAME="$(basename "${CKPT}" .pt)"
63
+ OUT_DIR="${OUT_DIR:-docs/lta_samples/metrics_20260519/${RUN_BASENAME}_${CKPT_BASENAME}_len128_lm1bgood_sdeish_n${N_SAMPLES}}"
64
+ OUT_JSONL="${OUT_DIR}/summary.jsonl"
65
+ mkdir -p "${OUT_DIR}"
66
+
67
+ echo "[infer] run=${RUN_DIR}"
68
+ echo "[infer] ckpt=${CKPT}"
69
+ echo "[infer] out=${OUT_JSONL}"
70
+ echo "[infer] decode_rule=${DECODE_RULE} steps=${STEPS} cmax=${CONCENTRATION_MAX} model_t=${MODEL_T_MODE} temps=${ENDPOINT_TEMPS}"
71
+
72
+ python scripts/standard_genppl_entropy_latest_decode.py \
73
+ --checkpoint "${CKPT}" \
74
+ --tokenizer_path "${TOKENIZER_PATH}" \
75
+ --scorer "${SCORER}" \
76
+ --output "${OUT_JSONL}" \
77
+ --max_len "${MAX_LEN}" \
78
+ --n_samples "${N_SAMPLES}" \
79
+ --decode_batch "${DECODE_BATCH}" \
80
+ --score_batch "${SCORE_BATCH}" \
81
+ --score_max_length "${MAX_LEN}" \
82
+ --steps "${STEPS}" \
83
+ --model_t_mode "${MODEL_T_MODE}" \
84
+ --decode_time_schedule "${TIME_SCHEDULE}" \
85
+ --decode_time_gumbel_loc "${TIME_GUMBEL_LOC}" \
86
+ --decode_time_gumbel_scale "${TIME_GUMBEL_SCALE}" \
87
+ --decode_rule "${DECODE_RULE}" \
88
+ --concentration_min "${CONCENTRATION_MIN}" \
89
+ --concentration_max "${CONCENTRATION_MAX}" \
90
+ --noise_init "${NOISE_INIT}" \
91
+ --endpoint_temps "${ENDPOINT_TEMPS}" \
92
+ --final_from "${FINAL_FROM}" \
93
+ --final_sample_mode "${FINAL_SAMPLE_MODE}" \
94
+ --save_samples "${N_SAMPLES}"
95
+
96
+ echo "[infer] summaries:"
97
+ python - "${OUT_JSONL}" <<'PY'
98
+ import json, sys
99
+ path = sys.argv[1]
100
+ with open(path, encoding="utf-8") as f:
101
+ for line in f:
102
+ row = json.loads(line)
103
+ if row.get("type") != "summary":
104
+ continue
105
+ d = row["decode"]
106
+ stripped = row.get("stripped_genppl", {})
107
+ div = row.get("diversity", {})
108
+ print(
109
+ f"temp={d['endpoint_temp']:.2f} final={d['final_from']} "
110
+ f"ppl={stripped.get('ppl')} entropy={div.get('sample_entropy')} "
111
+ f"top_mass={div.get('top_token_mass')}"
112
+ )
113
+ PY
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c1024_fullycoupled_8gpu_small_1m.sh ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
6
+ export TOKENIZERS_PARALLELISM=false
7
+ export PYTHONUNBUFFERED=1
8
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
9
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
10
+
11
+ # Fully-coupled t ablation:
12
+ # model_t == support/Dirichlet t == semantic endpoint t
13
+ RUN_NAME="${RUN_NAME:-lta_lm1b_dirichlet_categorical_fullvocab_c1024_fullycoupled_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0}"
14
+ DATA_PATH="${DATA_PATH:-data/lm1b_train_parquet}"
15
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
16
+ TEXT_COLUMN="${TEXT_COLUMN:-}"
17
+ OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-all}"
18
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
19
+ LOG_FILE="${LOG_FILE:-logs/${RUN_NAME}.log}"
20
+
21
+ NNODES="${NNODES:-1}"
22
+ NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
23
+ NODE_RANK="${NODE_RANK:-0}"
24
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
25
+ MASTER_PORT="${MASTER_PORT:-29631}"
26
+
27
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
28
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-64}"
29
+ TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
30
+ WARMUP_STEPS="${WARMUP_STEPS:-2500}"
31
+ MAX_LEN="${MAX_LEN:-128}"
32
+ WRAP_MODE="${WRAP_MODE:-stream}"
33
+ WRAP_RECORD_BUFFER_SIZE="${WRAP_RECORD_BUFFER_SIZE:-200}"
34
+ NUM_WORKERS="${NUM_WORKERS:-0}"
35
+ LOG_EVERY="${LOG_EVERY:-100}"
36
+ SAVE_EVERY="${SAVE_EVERY:-20000}"
37
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
38
+ EVAL_EVERY="${EVAL_EVERY:-0}"
39
+ RESUME_PATH="${RESUME_PATH:-}"
40
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
41
+ ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
42
+ FORCE_DISABLE_TORCH_COMPILE="${FORCE_DISABLE_TORCH_COMPILE:-1}"
43
+
44
+ if [[ "${FORCE_DISABLE_TORCH_COMPILE}" == "1" ]]; then
45
+ ENABLE_TORCH_COMPILE=0
46
+ fi
47
+ if [[ "${DATA_PATH}" == *"lm1b_train_parquet"* && "${NUM_WORKERS}" != "0" ]]; then
48
+ echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
49
+ NUM_WORKERS=0
50
+ fi
51
+
52
+ COMPILE_ARGS=()
53
+ if [[ "${ENABLE_TORCH_COMPILE}" == "1" ]]; then
54
+ COMPILE_ARGS+=(--torch_compile --compile_mode reduce-overhead)
55
+ fi
56
+ RESUME_ARGS=()
57
+ if [[ -n "${RESUME_PATH}" ]]; then
58
+ RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
59
+ fi
60
+ TEXT_COLUMN_ARGS=()
61
+ if [[ -n "${TEXT_COLUMN}" ]]; then
62
+ TEXT_COLUMN_ARGS+=(--text_column "${TEXT_COLUMN}")
63
+ fi
64
+
65
+ if [[ -f "${SAVE_DIR}/args.json" && -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
66
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
67
+ echo "Use a new RUN_NAME/SAVE_DIR, set RESUME_PATH to resume, or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
68
+ exit 2
69
+ fi
70
+
71
+ mkdir -p logs runs "${SAVE_DIR}"
72
+ echo "[launch] method=categorical_fullvocab_c1024_fullycoupled host=$(hostname) time=$(date -Iseconds)"
73
+ echo "[launch] cwd=$(pwd)"
74
+ echo "[launch] run_name=${RUN_NAME}"
75
+ echo "[launch] save_dir=${SAVE_DIR}"
76
+ echo "[launch] log_file=${LOG_FILE}"
77
+
78
+ python -m torch.distributed.run \
79
+ --nnodes="${NNODES}" \
80
+ --nproc_per_node="${NPROC_PER_NODE}" \
81
+ --node_rank="${NODE_RANK}" \
82
+ --master_addr="${MASTER_ADDR}" \
83
+ --master_port="${MASTER_PORT}" \
84
+ train.py \
85
+ --data_path "${DATA_PATH}" \
86
+ "${TEXT_COLUMN_ARGS[@]}" \
87
+ --openwebtext_split "${OPENWEBTEXT_SPLIT}" \
88
+ --tokenizer_path "${TOKENIZER_PATH}" \
89
+ --save_dir "${SAVE_DIR}" \
90
+ --wrap \
91
+ --wrap_mode "${WRAP_MODE}" \
92
+ --wrap_record_buffer_size "${WRAP_RECORD_BUFFER_SIZE}" \
93
+ --max_len "${MAX_LEN}" \
94
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
95
+ --num_workers "${NUM_WORKERS}" \
96
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
97
+ --total_steps "${TOTAL_STEPS}" \
98
+ --log_every "${LOG_EVERY}" \
99
+ --eval_every "${EVAL_EVERY}" \
100
+ --save_every "${SAVE_EVERY}" \
101
+ --latest_every "${LATEST_EVERY}" \
102
+ --lr 3e-4 \
103
+ --weight_decay 0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.999 \
106
+ --adam_eps 1e-8 \
107
+ --warmup_steps "${WARMUP_STEPS}" \
108
+ --lr_schedule constant_warmup \
109
+ --grad_clip 1.0 \
110
+ --seed 123 \
111
+ --d_model 768 \
112
+ --cond_dim 128 \
113
+ --n_layers 12 \
114
+ --n_heads 12 \
115
+ --dim_ff 3072 \
116
+ --dropout 0.1 \
117
+ --model_type ddit \
118
+ --state_format prob \
119
+ --bridge dirichlet \
120
+ --target_loss hard_ce \
121
+ --target_prob 1.0 \
122
+ --min_t 0.0 \
123
+ --max_t 1.0 \
124
+ --dual_t \
125
+ --corrupt_t_mode same \
126
+ --corrupt_min_t 0.0 \
127
+ --corrupt_max_t 1.0 \
128
+ --min_mask_ratio 0.1 \
129
+ --max_mask_ratio 1.0 \
130
+ --wrong_token_replace_prob 1.0 \
131
+ --wrong_token_schedule linear_t \
132
+ --wrong_token_exp_k 1.0 \
133
+ --dirichlet_concentration_min 1.0 \
134
+ --dirichlet_concentration_max 1024.0 \
135
+ --dirichlet_endpoint_mode categorical_dual_t \
136
+ --dirichlet_semantic_t_mode same \
137
+ --dirichlet_semantic_t_value 0.0 \
138
+ --categorical_wrong_from_full_vocab \
139
+ --simplex_bridge_sampler dirichlet \
140
+ --eps 1e-8 \
141
+ --infer_steps 128 \
142
+ --decode_damping 1.0 \
143
+ --max_gamma 1.0 \
144
+ --decode_solver flowmap \
145
+ --noise_init logistic_normal \
146
+ --bridge_noise_init logistic_normal \
147
+ --noise_sigma -1 \
148
+ "${RESUME_ARGS[@]}" \
149
+ "${COMPILE_ARGS[@]}" \
150
+ --bf16 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_categorical_fullvocab_c16_dualt_4gpu_small_1m.sh ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
6
+ export TOKENIZERS_PARALLELISM=false
7
+ export PYTHONUNBUFFERED=1
8
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
9
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
10
+
11
+ # C=16 categorical dual-t LM1B, full-vocab wrong-token endpoint.
12
+ # This is the 4-GPU counterpart of the 8-GPU full-vocab run; global batch stays 512.
13
+
14
+ C_MAX="${C_MAX:-16.0}"
15
+ C_TAG="${C_TAG:-c${C_MAX//./p}}"
16
+ RUN_NAME="${RUN_NAME:-lta_lm1b_dirichlet_categorical_fullvocab_${C_TAG}_dualt_flmpack_onehot_hardce_ddit_small_len128_gbs512_4gpu_1m_nw0}"
17
+ DATA_PATH="${DATA_PATH:-data/lm1b_train_parquet}"
18
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
19
+ DETOKENIZER="${DETOKENIZER:-auto}"
20
+ TEXT_COLUMN="${TEXT_COLUMN:-}"
21
+ OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-all}"
22
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
23
+ LOG_FILE="${LOG_FILE:-logs/${RUN_NAME}.log}"
24
+
25
+ NNODES="${NNODES:-1}"
26
+ NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
27
+ NODE_RANK="${NODE_RANK:-0}"
28
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
29
+ MASTER_PORT="${MASTER_PORT:-29641}"
30
+
31
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
32
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-64}"
33
+ TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
34
+ WARMUP_STEPS="${WARMUP_STEPS:-2500}"
35
+ MAX_LEN="${MAX_LEN:-128}"
36
+ WRAP_MODE="${WRAP_MODE:-stream}"
37
+ WRAP_RECORD_BUFFER_SIZE="${WRAP_RECORD_BUFFER_SIZE:-200}"
38
+ NUM_WORKERS="${NUM_WORKERS:-0}"
39
+ LOG_EVERY="${LOG_EVERY:-100}"
40
+ SAVE_EVERY="${SAVE_EVERY:-20000}"
41
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
42
+ EVAL_EVERY="${EVAL_EVERY:-0}"
43
+ RESUME_PATH="${RESUME_PATH:-}"
44
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
45
+ ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
46
+ FORCE_DISABLE_TORCH_COMPILE="${FORCE_DISABLE_TORCH_COMPILE:-1}"
47
+
48
+ if [[ "${FORCE_DISABLE_TORCH_COMPILE}" == "1" ]]; then
49
+ ENABLE_TORCH_COMPILE=0
50
+ fi
51
+ if [[ "${DATA_PATH}" == *"lm1b_train_parquet"* && "${NUM_WORKERS}" != "0" ]]; then
52
+ echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
53
+ NUM_WORKERS=0
54
+ fi
55
+
56
+ COMPILE_ARGS=()
57
+ if [[ "${ENABLE_TORCH_COMPILE}" == "1" ]]; then
58
+ COMPILE_ARGS+=(--torch_compile --compile_mode reduce-overhead)
59
+ fi
60
+ RESUME_ARGS=()
61
+ if [[ -n "${RESUME_PATH}" ]]; then
62
+ RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
63
+ fi
64
+ TEXT_COLUMN_ARGS=()
65
+ if [[ -n "${TEXT_COLUMN}" ]]; then
66
+ TEXT_COLUMN_ARGS+=(--text_column "${TEXT_COLUMN}")
67
+ fi
68
+
69
+ if [[ -f "${SAVE_DIR}/args.json" && -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
70
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
71
+ echo "Use a new RUN_NAME/SAVE_DIR, set RESUME_PATH to resume, or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
72
+ exit 2
73
+ fi
74
+
75
+ mkdir -p logs runs "${SAVE_DIR}"
76
+ echo "[launch] method=categorical_fullvocab C_MAX=${C_MAX} host=$(hostname) time=$(date -Iseconds)"
77
+ echo "[launch] cwd=$(pwd)"
78
+ echo "[launch] run_name=${RUN_NAME}"
79
+ echo "[launch] save_dir=${SAVE_DIR}"
80
+ echo "[launch] log_file=${LOG_FILE}"
81
+ echo "[launch] nproc_per_node=${NPROC_PER_NODE} global_batch_size=${GLOBAL_BATCH_SIZE} per_gpu_batch_size=${PER_GPU_BATCH_SIZE}"
82
+
83
+ python -m torch.distributed.run \
84
+ --nnodes="${NNODES}" \
85
+ --nproc_per_node="${NPROC_PER_NODE}" \
86
+ --node_rank="${NODE_RANK}" \
87
+ --master_addr="${MASTER_ADDR}" \
88
+ --master_port="${MASTER_PORT}" \
89
+ train.py \
90
+ --data_path "${DATA_PATH}" \
91
+ "${TEXT_COLUMN_ARGS[@]}" \
92
+ --openwebtext_split "${OPENWEBTEXT_SPLIT}" \
93
+ --detokenizer "${DETOKENIZER}" \
94
+ --tokenizer_path "${TOKENIZER_PATH}" \
95
+ --save_dir "${SAVE_DIR}" \
96
+ --wrap \
97
+ --wrap_mode "${WRAP_MODE}" \
98
+ --wrap_record_buffer_size "${WRAP_RECORD_BUFFER_SIZE}" \
99
+ --max_len "${MAX_LEN}" \
100
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
101
+ --num_workers "${NUM_WORKERS}" \
102
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
103
+ --total_steps "${TOTAL_STEPS}" \
104
+ --log_every "${LOG_EVERY}" \
105
+ --eval_every "${EVAL_EVERY}" \
106
+ --save_every "${SAVE_EVERY}" \
107
+ --latest_every "${LATEST_EVERY}" \
108
+ --lr 3e-4 \
109
+ --weight_decay 0 \
110
+ --adam_beta1 0.9 \
111
+ --adam_beta2 0.999 \
112
+ --adam_eps 1e-8 \
113
+ --warmup_steps "${WARMUP_STEPS}" \
114
+ --lr_schedule constant_warmup \
115
+ --grad_clip 1.0 \
116
+ --seed 123 \
117
+ --d_model 768 \
118
+ --cond_dim 128 \
119
+ --n_layers 12 \
120
+ --n_heads 12 \
121
+ --dim_ff 3072 \
122
+ --dropout 0.1 \
123
+ --model_type ddit \
124
+ --state_format prob \
125
+ --bridge dirichlet \
126
+ --target_loss hard_ce \
127
+ --target_prob 1.0 \
128
+ --min_t 0.0 \
129
+ --max_t 1.0 \
130
+ --dual_t \
131
+ --corrupt_t_mode independent \
132
+ --corrupt_min_t 0.0 \
133
+ --corrupt_max_t 1.0 \
134
+ --min_mask_ratio 0.1 \
135
+ --max_mask_ratio 1.0 \
136
+ --wrong_token_replace_prob 1.0 \
137
+ --wrong_token_schedule linear_t \
138
+ --wrong_token_exp_k 1.0 \
139
+ --dirichlet_concentration_min 1.0 \
140
+ --dirichlet_concentration_max "${C_MAX}" \
141
+ --dirichlet_endpoint_mode categorical_dual_t \
142
+ --dirichlet_semantic_t_mode independent \
143
+ --dirichlet_semantic_t_value 0.0 \
144
+ --categorical_wrong_from_full_vocab \
145
+ --eps 1e-8 \
146
+ --infer_steps 128 \
147
+ --decode_damping 1.0 \
148
+ --max_gamma 1.0 \
149
+ --decode_solver flowmap \
150
+ --noise_init logistic_normal \
151
+ --bridge_noise_init logistic_normal \
152
+ --noise_sigma -1 \
153
+ "${RESUME_ARGS[@]}" \
154
+ "${COMPILE_ARGS[@]}" \
155
+ --bf16 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_c1024_fullycoupled_8gpu_len1024_gpt2_cached_chunks_1m.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ # Explicit cached-chunk OWT/GPT-2 run.
7
+ # Uses the already-built cache:
8
+ # openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k
9
+ #
10
+ # Data processing:
11
+ # tokenize records with GPT-2 tokenizer
12
+ # append GPT-2 EOT after each record
13
+ # concatenate stream
14
+ # split into payload_len=1022
15
+ # wrap as [EOT] + payload + [EOT]
16
+ # train from fixed memmap chunks with DistributedSampler shuffle
17
+
18
+ export OWT_CACHED_CHUNKS=1
19
+ export OWT_CHUNK_CACHE_DIR="${OWT_CHUNK_CACHE_DIR:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
20
+ # Default to reusing the prebuilt cache; set OWT_CHUNK_CACHE_REBUILD=1 only when
21
+ # intentionally refreshing or repairing the cached chunk pool.
22
+ export OWT_CHUNK_CACHE_REBUILD="${OWT_CHUNK_CACHE_REBUILD:-0}"
23
+
24
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
25
+ export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
26
+ export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
27
+ export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
28
+ export WARMUP_STEPS="${WARMUP_STEPS:-2000}"
29
+ export MAX_LEN="${MAX_LEN:-1024}"
30
+ export NUM_WORKERS="${NUM_WORKERS:-4}"
31
+ export DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-2}"
32
+ export LOG_EVERY="${LOG_EVERY:-100}"
33
+ export SAVE_EVERY="${SAVE_EVERY:-20000}"
34
+ export LATEST_EVERY="${LATEST_EVERY:-1000}"
35
+ export EVAL_EVERY="${EVAL_EVERY:-0}"
36
+ export ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
37
+ export ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
38
+ export OPTIMIZER="${OPTIMIZER:-adamw}"
39
+ export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}"
40
+ export MUON_NS_STEPS="${MUON_NS_STEPS:-5}"
41
+ export MUON_UPDATE_SCALE="${MUON_UPDATE_SCALE:-1.0}"
42
+ export EMA_DECAY="${EMA_DECAY:-0.0}"
43
+ export EMA_START_STEP="${EMA_START_STEP:-0}"
44
+ export ALLOW_TF32="${ALLOW_TF32:-1}"
45
+ export ACTIVATION_CHECKPOINTING="${ACTIVATION_CHECKPOINTING:-0}"
46
+ export ACTIVATION_CHECKPOINT_INTERVAL="${ACTIVATION_CHECKPOINT_INTERVAL:-1}"
47
+ export DDP_STATIC_GRAPH="${DDP_STATIC_GRAPH:-0}"
48
+ export DDP_GRADIENT_AS_BUCKET_VIEW="${DDP_GRADIENT_AS_BUCKET_VIEW:-1}"
49
+ export BLOCKING_DATA_TRANSFER="${BLOCKING_DATA_TRANSFER:-0}"
50
+ export FULL_TRAIN_STATS="${FULL_TRAIN_STATS:-0}"
51
+
52
+ export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
53
+ export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
54
+ export TEXT_COLUMN="${TEXT_COLUMN:-text}"
55
+ export OPENWEBTEXT_SPLIT="${OPENWEBTEXT_SPLIT:-train_minus_100k}"
56
+ export DETOKENIZER="${DETOKENIZER:-auto}"
57
+
58
+ export RUN_NAME="${RUN_NAME:-lta_owt_dirichlet_categorical_fullvocab_c1024_fullycoupled_gpt2_cached_chunks_len1024_gbs${GLOBAL_BATCH_SIZE}_${NPROC_PER_NODE}gpu_1m_nw${NUM_WORKERS}}"
59
+
60
+ bash scripts/launch_lta_owt_categorical_fullvocab_c1024_fullycoupled_8gpu_small_1m.sh
LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v8192_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ # 8k compact GPT2-BPE variant of the v2048 fully-coupled mask=1 baseline.
7
+ # Keep the actual training recipe centralized in the v2048 script; this wrapper
8
+ # only swaps tokenizer/data/run labels.
9
+ export VOCAB_SIZE="${VOCAB_SIZE:-8192}"
10
+ export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-compact-gpt2bpe-v8192-stream1024-train-minus-100k}"
11
+ export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/lta_tokenizers/owt_compact_gpt2bpe_v8192/tokenizer.json}"
12
+ export COMPACT_VARIANT_LABEL="${COMPACT_VARIANT_LABEL:-compact_gpt2bpe_v8192_stream1024_fullycoupled_mask0p1-1p0_wd0p1_fp32}"
13
+ export T_SAMPLING_MODE="${T_SAMPLING_MODE:-uniform}"
14
+ export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
15
+ export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
16
+
17
+ sanitize_label() {
18
+ printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
19
+ }
20
+
21
+ T_SAMPLING_LOGIT_MEAN_FOR_NAME="${T_SAMPLING_LOGIT_MEAN:--1.5}"
22
+ T_SAMPLING_LOGIT_STD_FOR_NAME="${T_SAMPLING_LOGIT_STD:-0.8}"
23
+ MIN_MASK_RATIO_FOR_NAME="${MIN_MASK_RATIO:-1.0}"
24
+ MAX_MASK_RATIO_FOR_NAME="${MAX_MASK_RATIO:-1.0}"
25
+
26
+ T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN_FOR_NAME}")"
27
+ T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD_FOR_NAME}")"
28
+ MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO_FOR_NAME}")"
29
+ MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO_FOR_NAME}")"
30
+ if [[ "${T_SAMPLING_MODE}" == "logit_normal" ]]; then
31
+ T_SAMPLING_LABEL="logitnormal_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}"
32
+ else
33
+ T_SAMPLING_LABEL="$(sanitize_label "${T_SAMPLING_MODE}")t"
34
+ fi
35
+
36
+ export RUN_NAME="${RUN_NAME:-lta_owt_compact_gpt2bpe_v8192_stream1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_${T_SAMPLING_LABEL}_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
37
+ export LOG_DIR="${LOG_DIR:-logs/compact_gpt2bpe_v8192_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu}"
38
+
39
+ bash scripts/launch_lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh
LTA_openwebtext_dualt/scripts/launch_lta_owt_elfaligned_t5_logitnormal_8gpu.sh ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
8
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
13
+
14
+ # ELF-aligned simplex run:
15
+ # architecture: ddit_elf = no adaLN, prefix time tokens, qk norm, RoPE, RMSNorm, SwiGLU
16
+ # tokenizer/data: T5-small tokenizer, one OWT record per example, pad/truncate to 1024
17
+ # optimizer: Muon, lr 0.002, wd 0, constant LR after 0.5 epoch warmup
18
+ # time sampling: sigmoid(N(T_LOGIT_MEAN, T_LOGIT_STD^2)); defaults match ELF
19
+ # The old ddit path and GPT2 cached scripts are untouched.
20
+
21
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
22
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
23
+
24
+ NNODES="${NNODES:-1}"
25
+ NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
26
+ NODE_RANK="${NODE_RANK:-0}"
27
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
28
+ MASTER_PORT="${MASTER_PORT:-32091}"
29
+
30
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
31
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
32
+ EPOCHS="${EPOCHS:-5}"
33
+ NUM_WORKERS="${NUM_WORKERS:-8}"
34
+ DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
35
+ LOG_EVERY="${LOG_EVERY:-100}"
36
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
37
+ EVAL_EVERY="${EVAL_EVERY:-0}"
38
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
39
+ ALLOW_TF32="${ALLOW_TF32:-1}"
40
+
41
+ LR="${LR:-0.002}"
42
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.0}"
43
+ ADAM_BETA1="${ADAM_BETA1:-0.9}"
44
+ ADAM_BETA2="${ADAM_BETA2:-0.999}"
45
+ ADAM_EPS="${ADAM_EPS:-1e-8}"
46
+ MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}"
47
+ MUON_NS_STEPS="${MUON_NS_STEPS:-5}"
48
+ MUON_UPDATE_SCALE="${MUON_UPDATE_SCALE:-1.0}"
49
+ GRAD_CLIP="${GRAD_CLIP:-1.0}"
50
+ EMA_DECAY="${EMA_DECAY:-0.9999}"
51
+ EMA_START_STEP="${EMA_START_STEP:-0}"
52
+ T_LOGIT_MEAN="${T_LOGIT_MEAN:--1.5}"
53
+ T_LOGIT_STD="${T_LOGIT_STD:-0.8}"
54
+ LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
55
+ LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
56
+ OUTPUT_INIT_STD="${OUTPUT_INIT_STD:-0.0}"
57
+
58
+ sanitize_label() {
59
+ printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
60
+ }
61
+
62
+ T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_LOGIT_MEAN}")"
63
+ T_LOGIT_STD_LABEL="$(sanitize_label "${T_LOGIT_STD}")"
64
+ LOSS_T_MIN_WEIGHT_LABEL="$(sanitize_label "${LOSS_T_MIN_WEIGHT}")"
65
+
66
+ RUN_NAME="${RUN_NAME:-lta_owt_t5record_len1024_elfaligned_dditelf_muon_logitnormal_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}_${LOSS_T_WEIGHT_MODE}_floor${LOSS_T_MIN_WEIGHT_LABEL}_gbs512_8gpu_5epoch_$(date +%Y%m%d_%H%M%S)}"
67
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
68
+ LOG_DIR="${LOG_DIR:-logs/elfaligned_t5record_8gpu}"
69
+ LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
70
+
71
+ NUM_RECORDS=$(python - <<PY
72
+ from pathlib import Path
73
+ import pyarrow.parquet as pq
74
+ root = Path("${DATA_PATH}")
75
+ files = sorted(root.rglob("*.parquet")) if root.is_dir() else [root]
76
+ rows = sum(pq.ParquetFile(str(p)).metadata.num_rows for p in files)
77
+ print(max(0, rows - 100_000))
78
+ PY
79
+ )
80
+ STEPS_PER_EPOCH=$(( (NUM_RECORDS + GLOBAL_BATCH_SIZE - 1) / GLOBAL_BATCH_SIZE ))
81
+ SAVE_EVERY="${SAVE_EVERY:-${STEPS_PER_EPOCH}}"
82
+
83
+ if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
84
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
85
+ echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
86
+ exit 2
87
+ fi
88
+
89
+ mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
90
+
91
+ TF32_FLAG="--allow_tf32"
92
+ TF32_LABEL="true"
93
+ if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
94
+ TF32_FLAG="--no-allow_tf32"
95
+ TF32_LABEL="false"
96
+ fi
97
+
98
+ echo "[launch] method=owt_elfaligned_t5record_dditelf host=$(hostname) time=$(date -Iseconds)"
99
+ echo "[launch] run_name=${RUN_NAME}"
100
+ echo "[launch] save_dir=${SAVE_DIR}"
101
+ echo "[launch] log_file=${LOG_FILE}"
102
+ echo "[launch] data_path=${DATA_PATH}"
103
+ echo "[launch] tokenizer=${TOKENIZER_PATH}"
104
+ echo "[launch] records=${NUM_RECORDS} epochs=${EPOCHS} approx_steps_per_epoch=${STEPS_PER_EPOCH} save_every=${SAVE_EVERY}"
105
+ echo "[launch] optimizer=muon_impl=optax grouping=hidden_2d lr=${LR} wd=${WEIGHT_DECAY} adam_fallback_wd=0 momentum=${MUON_MOMENTUM} ns=${MUON_NS_STEPS} nesterov=true width_scale=true adam_fallback_b2=${ADAM_BETA2} ema=${EMA_DECAY}"
106
+ echo "[launch] model=ddit_elf rmsnorm qk_norm=true swiglu no_adaln output_bias=false output_init_std=${OUTPUT_INIT_STD} time_tokens=4 mode_tokens=0"
107
+ echo "[launch] data=record_pad_truncate pad=pad add_special_tokens=false t5-small fp32=true bf16=false tf32=${TF32_LABEL}"
108
+ echo "[launch] t_sampling=logit_normal mean=${T_LOGIT_MEAN} std=${T_LOGIT_STD} loss_t_weight=${LOSS_T_WEIGHT_MODE} loss_t_min_weight=${LOSS_T_MIN_WEIGHT} warmup_epochs=0.5"
109
+
110
+ python -m torch.distributed.run \
111
+ --nnodes="${NNODES}" \
112
+ --nproc_per_node="${NPROC_PER_NODE}" \
113
+ --node_rank="${NODE_RANK}" \
114
+ --master_addr="${MASTER_ADDR}" \
115
+ --master_port="${MASTER_PORT}" \
116
+ train.py \
117
+ --data_path "${DATA_PATH}" \
118
+ --openwebtext_split train_minus_100k \
119
+ --text_column text \
120
+ --detokenizer auto \
121
+ --tokenizer_path "${TOKENIZER_PATH}" \
122
+ --save_dir "${SAVE_DIR}" \
123
+ --record_pad_truncate \
124
+ --record_pad_token pad \
125
+ --record_shuffle_buffer 10000 \
126
+ --max_len 1024 \
127
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
128
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
129
+ --num_workers "${NUM_WORKERS}" \
130
+ --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
131
+ --epochs "${EPOCHS}" \
132
+ --total_steps 1 \
133
+ --warmup_epochs 0.5 \
134
+ --log_every "${LOG_EVERY}" \
135
+ --eval_every "${EVAL_EVERY}" \
136
+ --save_every "${SAVE_EVERY}" \
137
+ --latest_every "${LATEST_EVERY}" \
138
+ --optimizer muon \
139
+ --muon_impl optax \
140
+ --lr "${LR}" \
141
+ --lr_schedule constant_warmup \
142
+ --min_lr 0 \
143
+ --weight_decay "${WEIGHT_DECAY}" \
144
+ --adam_beta1 "${ADAM_BETA1}" \
145
+ --adam_beta2 "${ADAM_BETA2}" \
146
+ --adam_eps "${ADAM_EPS}" \
147
+ --muon_momentum "${MUON_MOMENTUM}" \
148
+ --muon_ns_steps "${MUON_NS_STEPS}" \
149
+ --muon_update_scale "${MUON_UPDATE_SCALE}" \
150
+ --muon_nesterov \
151
+ --muon_width_scale \
152
+ --ema_decay "${EMA_DECAY}" \
153
+ --ema_start_step "${EMA_START_STEP}" \
154
+ --grad_clip "${GRAD_CLIP}" \
155
+ --seed 42 \
156
+ --d_model 768 \
157
+ --cond_dim 128 \
158
+ --n_layers 12 \
159
+ --n_heads 12 \
160
+ --dim_ff 3072 \
161
+ --dropout 0.0 \
162
+ --no-output_bias \
163
+ --output_init_std "${OUTPUT_INIT_STD}" \
164
+ --norm_type rmsnorm \
165
+ --model_type ddit_elf \
166
+ --elf_num_time_tokens 4 \
167
+ --elf_num_model_mode_tokens 0 \
168
+ --qk_norm \
169
+ --state_format prob \
170
+ --bridge dirichlet \
171
+ --target_loss hard_ce \
172
+ --loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
173
+ --loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
174
+ --target_prob 1.0 \
175
+ --min_t 0.0 \
176
+ --max_t 1.0 \
177
+ --t_sampling_mode logit_normal \
178
+ --t_sampling_logit_mean "${T_LOGIT_MEAN}" \
179
+ --t_sampling_logit_std "${T_LOGIT_STD}" \
180
+ --t_sampling_eps 1e-4 \
181
+ --dual_t \
182
+ --corrupt_t_mode same \
183
+ --corrupt_min_t 0.0 \
184
+ --corrupt_max_t 1.0 \
185
+ --min_mask_ratio 0.1 \
186
+ --max_mask_ratio 1.0 \
187
+ --wrong_token_replace_prob 1.0 \
188
+ --wrong_token_schedule linear_t \
189
+ --wrong_token_exp_k 1.0 \
190
+ --dirichlet_concentration_min 1.0 \
191
+ --dirichlet_concentration_max 1024 \
192
+ --dirichlet_endpoint_mode categorical_dual_t \
193
+ --dirichlet_semantic_t_mode same \
194
+ --dirichlet_semantic_t_value 0.0 \
195
+ --categorical_wrong_from_full_vocab \
196
+ --simplex_bridge_sampler dirichlet \
197
+ --eps 1e-8 \
198
+ --infer_steps 1024 \
199
+ --decode_damping 1.0 \
200
+ --max_gamma 1.0 \
201
+ --decode_solver flowmap \
202
+ --noise_init logistic_normal \
203
+ --bridge_noise_init logistic_normal \
204
+ --noise_sigma -1 \
205
+ "${TF32_FLAG}" \
206
+ --activation_checkpointing \
207
+ --activation_checkpoint_scope mlp \
208
+ --ddp_gradient_as_bucket_view \
209
+ 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_outwd0p5_8gpu.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:-0.5}"
7
+ export WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
8
+ export RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
9
+ export LOG_DIR="${LOG_DIR:-logs/fullycoupled_outwd0p5_8gpu}"
10
+
11
+ bash scripts/launch_lta_owt_fullycoupled_wd0p1_fp32_8gpu.sh
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_grad_k1_rho025_subset10k_4gpu_100k.sh ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
7
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
8
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+
13
+ free_port() {
14
+ python3 - <<'PY'
15
+ import socket
16
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
17
+ s.bind(("127.0.0.1", 0))
18
+ print(s.getsockname()[1])
19
+ PY
20
+ }
21
+
22
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
23
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
24
+ MAX_RECORDS="${MAX_RECORDS:-10000}"
25
+ TOTAL_STEPS="${TOTAL_STEPS:-100000}"
26
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
27
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-16}"
28
+ RUN_NAME="${RUN_NAME:-lta_owt_t5_3l_d256_rollin_grad_p50_k1_rho0_0p25_uniformt_maxrec10k_4gpu_100k_$(date +%Y%m%d_%H%M%S)}"
29
+ MASTER_PORT="${MASTER_PORT:-$(free_port)}"
30
+ LOG_DIR="${LOG_DIR:-logs/elfaligned_t5tokenized_4gpu}"
31
+ mkdir -p "${LOG_DIR}" "runs/${RUN_NAME}"
32
+ LOG_FILE="${LOG_DIR}/${RUN_NAME}.log"
33
+
34
+ echo "[launch] run_name=${RUN_NAME}" | tee -a "${LOG_FILE}"
35
+ echo "[launch] data=${DATA_PATH} max_records=${MAX_RECORDS} tokenizer=${TOKENIZER_PATH}" | tee -a "${LOG_FILE}"
36
+ echo "[launch] cuda=${CUDA_VISIBLE_DEVICES} nproc=${NPROC_PER_NODE} gbs=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE} total_steps=${TOTAL_STEPS}" | tee -a "${LOG_FILE}"
37
+
38
+ torchrun \
39
+ --nproc_per_node="${NPROC_PER_NODE}" \
40
+ --master_port="${MASTER_PORT}" \
41
+ train.py \
42
+ --data_path "${DATA_PATH}" \
43
+ --max_records "${MAX_RECORDS}" \
44
+ --tokenized_hf \
45
+ --tokenized_pad_token pad \
46
+ --tokenizer_path "${TOKENIZER_PATH}" \
47
+ --save_dir "runs/${RUN_NAME}" \
48
+ --max_len 1024 \
49
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
50
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
51
+ --num_workers 0 \
52
+ --epochs 0 \
53
+ --total_steps "${TOTAL_STEPS}" \
54
+ --warmup_steps 1 \
55
+ --warmup_epochs 0.5 \
56
+ --log_every 100 \
57
+ --eval_every 0 \
58
+ --save_every 5000 \
59
+ --latest_every 1000 \
60
+ --optimizer muon \
61
+ --muon_impl optax \
62
+ --lr 0.002 \
63
+ --lr_schedule constant_warmup \
64
+ --min_lr 0.0 \
65
+ --weight_decay 0.1 \
66
+ --output_weight_decay -1 \
67
+ --adamw_param_groups nanogpt \
68
+ --adam_beta1 0.9 \
69
+ --adam_beta2 0.999 \
70
+ --adam_eps 1e-8 \
71
+ --ema_decay 0.9999 \
72
+ --ema_start_step 0 \
73
+ --grad_clip 1.0 \
74
+ --seed 42 \
75
+ --d_model 256 \
76
+ --cond_dim 128 \
77
+ --n_layers 3 \
78
+ --n_heads 4 \
79
+ --dim_ff 1024 \
80
+ --dropout 0.0 \
81
+ --no-output_bias \
82
+ --output_init_std 0 \
83
+ --norm_type rmsnorm \
84
+ --qk_norm \
85
+ --model_type ddit_elf \
86
+ --ddit_mlp_type gelu \
87
+ --state_format prob \
88
+ --bridge dirichlet \
89
+ --target_loss hard_ce \
90
+ --loss_t_weight_mode none \
91
+ --loss_t_min_weight 0.0 \
92
+ --rollout_train_prob 0.50 \
93
+ --rollout_train_time_mode sampled_path \
94
+ --rollout_train_steps 1 \
95
+ --rollout_train_steps_min -1 \
96
+ --rollout_train_infer_steps 1 \
97
+ --rollout_train_s_dist uniform \
98
+ --rollout_train_s_min_frac 0.0 \
99
+ --rollout_train_s_max_frac 0.25 \
100
+ --rollout_train_temp 1.0 \
101
+ --rollout_train_max_gamma 1.0 \
102
+ --rollout_train_corrupt_only \
103
+ --rollout_train_samplewise \
104
+ --rollout_train_selected_only \
105
+ --no-rollout_train_compute_always \
106
+ --rollout_train_keep_grad \
107
+ --rollout_train_sync_t \
108
+ --target_prob 1.0 \
109
+ --min_t 0.0 \
110
+ --max_t 1.0 \
111
+ --t_sampling_mode uniform \
112
+ --t_sampling_logit_mean -1.5 \
113
+ --t_sampling_logit_std 0.8 \
114
+ --t_sampling_eps 1e-4 \
115
+ --dual_t \
116
+ --corrupt_t_mode same \
117
+ --corrupt_min_t 0.0 \
118
+ --corrupt_max_t 1.0 \
119
+ --min_mask_ratio 1.0 \
120
+ --max_mask_ratio 1.0 \
121
+ --mask_mixture_original_prob 0.0 \
122
+ --mask_mixture_lowk_prob 0.0 \
123
+ --mask_mixture_lowcorrupt_prob 0.0 \
124
+ --mask_mixture_block_prob 0.0 \
125
+ --mask_mixture_all_prob 1.0 \
126
+ --wrong_token_replace_prob 1.0 \
127
+ --wrong_token_schedule linear_t \
128
+ --wrong_token_exp_k 1.0 \
129
+ --dirichlet_concentration_min 1.0 \
130
+ --dirichlet_concentration_max 1024 \
131
+ --dirichlet_endpoint_mode categorical_dual_t \
132
+ --dirichlet_semantic_t_mode same \
133
+ --dirichlet_semantic_t_value 0.0 \
134
+ --categorical_wrong_from_full_vocab \
135
+ --simplex_bridge_sampler dirichlet \
136
+ --eps 1e-8 \
137
+ --infer_steps 1024 \
138
+ --decode_damping 1.0 \
139
+ --max_gamma 1.0 \
140
+ --decode_solver flowmap \
141
+ --noise_init logistic_normal \
142
+ --bridge_noise_init logistic_normal \
143
+ --noise_sigma -1 \
144
+ --allow_tf32 \
145
+ --activation_checkpointing \
146
+ --activation_checkpoint_scope mlp \
147
+ --ddp_gradient_as_bucket_view \
148
+ 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/run_lta_lm1b_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
8
+ export MASTER_PORT="${MASTER_PORT:-32682}"
9
+
10
+ export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
11
+ export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
12
+ export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
13
+ export WARMUP_STEPS="${WARMUP_STEPS:-2500}"
14
+ export SAVE_EVERY="${SAVE_EVERY:-10000}"
15
+ export LATEST_EVERY="${LATEST_EVERY:-1000}"
16
+ export LOG_EVERY="${LOG_EVERY:-100}"
17
+
18
+ export MAX_LEN="${MAX_LEN:-1024}"
19
+ export VOCAB_SIZE="${VOCAB_SIZE:-30522}"
20
+ export CMIN="${CMIN:-${VOCAB_SIZE}}"
21
+ export CMAX="${CMAX:-61044}"
22
+
23
+ export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
24
+ export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
25
+ export CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
26
+
27
+ # Keep watcher off by default for the 1M run; enable explicitly to avoid
28
+ # competing with training GPUs on busy 8-card nodes.
29
+ export WATCH_ENABLED="${WATCH_ENABLED:-0}"
30
+
31
+ DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
32
+ export RUN_NAME="${RUN_NAME:-lta_lm1b_dirichlet_len1024_Cv_to_2v_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
33
+
34
+ bash scripts/run_lta_lm1b_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
LTA_openwebtext_dualt/scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_1m_save10k.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
8
+ export MASTER_PORT="${MASTER_PORT:-32682}"
9
+
10
+ export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
11
+ export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
12
+ export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
13
+ export WARMUP_STEPS="${WARMUP_STEPS:-2500}"
14
+ export SAVE_EVERY="${SAVE_EVERY:-10000}"
15
+ export LATEST_EVERY="${LATEST_EVERY:-1000}"
16
+ export LOG_EVERY="${LOG_EVERY:-100}"
17
+
18
+ export MAX_LEN="${MAX_LEN:-1024}"
19
+ export VOCAB_SIZE="${VOCAB_SIZE:-30522}"
20
+ export CMIN="${CMIN:-${VOCAB_SIZE}}"
21
+ export CMAX="${CMAX:-61044}"
22
+
23
+ export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
24
+ export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
25
+ export CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
26
+
27
+ # Keep watcher off by default for the 1M run; enable explicitly to avoid
28
+ # competing with training GPUs on busy 8-card nodes.
29
+ export WATCH_ENABLED="${WATCH_ENABLED:-0}"
30
+
31
+ DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
32
+ export RUN_NAME="${RUN_NAME:-lta_owt_dirichlet_len1024_Cv_to_2v_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
33
+
34
+ bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
LTA_openwebtext_dualt/scripts/run_lta_owt_t5_absrope_adaln_dirichlet_len1024_Cv_to_2v_8gpu_mask0p1_1p0_sameT_1m_save10k.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ # T5-tokenized OWT, DDiT = RoPE + adaLN-zero, with learned absolute position
7
+ # embeddings added before RoPE. The bridge/model t is shared (sameT).
8
+ export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
9
+ export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
10
+ export TOKENIZED_HF="${TOKENIZED_HF:-1}"
11
+ export TOKENIZED_PAD_TOKEN="${TOKENIZED_PAD_TOKEN:-pad}"
12
+
13
+ export VOCAB_SIZE="${VOCAB_SIZE:-32100}"
14
+ export CMIN="${CMIN:-32100}"
15
+ export CMAX="${CMAX:-64200}"
16
+
17
+ export ABS_POS_EMBED="${ABS_POS_EMBED:-1}"
18
+ export CORRUPT_T_MODE="${CORRUPT_T_MODE:-same}"
19
+ export MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
20
+ export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
21
+ export MASK_MIXTURE_ORIGINAL_PROB="${MASK_MIXTURE_ORIGINAL_PROB:-0.0}"
22
+ export MASK_MIXTURE_ALL_PROB="${MASK_MIXTURE_ALL_PROB:-0.0}"
23
+
24
+ export DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
25
+ export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
26
+ export SAVE_EVERY="${SAVE_EVERY:-10000}"
27
+ export LATEST_EVERY="${LATEST_EVERY:-1000}"
28
+ export WATCH_ENABLED="${WATCH_ENABLED:-1}"
29
+ export WATCH_STEP_INTERVAL="${WATCH_STEP_INTERVAL:-10000}"
30
+ export WATCH_N_SAMPLES="${WATCH_N_SAMPLES:-128}"
31
+ export WATCH_CUDA_VISIBLE_DEVICES="${WATCH_CUDA_VISIBLE_DEVICES:-7}"
32
+
33
+ export RUN_NAME="${RUN_NAME:-lta_owt_t5_absrope_adaln_dirichlet_len1024_Cv_to_2v_mask0p1_1p0_sameT_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
34
+ export WATCH_OUT_BASE="${WATCH_OUT_BASE:-docs/lta_samples/metrics_${DATE_TAG}/owt_t5_absrope_adaln_Cv_to_2v_mask0p1_1p0_sameT_sde_gumbel_topp${WATCH_ENDPOINT_TOP_P:-0.95}_tau${WATCH_GUMBEL_TAU_START:-1.0}_to_${WATCH_GUMBEL_TAU_END:-0.2}_blend_c${CMIN}_${CMAX}_n${WATCH_N_SAMPLES}/${RUN_NAME}}"
35
+
36
+ bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
LTA_openwebtext_dualt/scripts/run_train8_wrong_floor_pilots_4gpu.sh ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
7
+ export TOKENIZERS_PARALLELISM=false
8
+ export PYTHONUNBUFFERED=1
9
+
10
+ BASE_CACHE="${BASE_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
11
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
12
+ MAX_LEN="${MAX_LEN:-256}"
13
+ N_SAMPLES="${N_SAMPLES:-64}"
14
+ INFER_STEPS="${INFER_STEPS:-128}"
15
+ STEP_CHUNK="${STEP_CHUNK:-1000}"
16
+ MAX_TOTAL_STEPS="${MAX_TOTAL_STEPS:-20000}"
17
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-128}"
18
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
19
+ GROUP_STAMP="${GROUP_STAMP:-$(date +%Y%m%d_%H%M%S)}"
20
+ OUT_ROOT="${OUT_ROOT:-docs/lta_samples/metrics_20260517/wrong_floor_pilots_len${MAX_LEN}_bs512_ode128_${GROUP_STAMP}}"
21
+ DRIVER_LOG="${DRIVER_LOG:-logs/wrong_floor_pilots_4gpu/${GROUP_STAMP}.log}"
22
+ CURVE_CSV="${CURVE_CSV:-${OUT_ROOT}/hit_ratio_curve.csv}"
23
+ mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"
24
+
25
+ cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
26
+ vocab_size="$(
27
+ python - "$cache" <<'PY'
28
+ import json
29
+ import sys
30
+ from pathlib import Path
31
+ meta = json.loads((Path(sys.argv[1]) / "meta.json").read_text())
32
+ print(int(meta.get("compact_vocab_size", meta.get("vocab_size"))))
33
+ PY
34
+ )"
35
+
36
+ if [[ ! -f "${CURVE_CSV}" ]]; then
37
+ echo "config,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
38
+ fi
39
+
40
+ latest_step() {
41
+ local run_name="$1"
42
+ python - "$run_name" <<'PY'
43
+ import re
44
+ import sys
45
+ from pathlib import Path
46
+ run = Path("runs") / sys.argv[1]
47
+ steps = []
48
+ for path in run.glob("step_*.pt"):
49
+ m = re.search(r"step_(\d+)\.pt$", path.name)
50
+ if m:
51
+ steps.append(int(m.group(1)))
52
+ print(max(steps) if steps else 0)
53
+ PY
54
+ }
55
+
56
+ free_port() {
57
+ python - <<'PY'
58
+ import socket
59
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
60
+ s.bind(("127.0.0.1", 0))
61
+ print(s.getsockname()[1])
62
+ PY
63
+ }
64
+
65
+ eval_latest() {
66
+ local config="$1"
67
+ local run_name="$2"
68
+ local target_step="$3"
69
+ local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
70
+ mkdir -p "${out_dir}"
71
+ CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py \
72
+ --runs_glob "runs/${run_name}" \
73
+ --data_dir "${cache}" \
74
+ --tokenizer_path "${TOKENIZER_PATH}" \
75
+ --out_dir "${out_dir}" \
76
+ --max_len "${MAX_LEN}" \
77
+ --n_samples "${N_SAMPLES}" \
78
+ --batch_size "${N_SAMPLES}" \
79
+ --latest_only \
80
+ --endpoint_softenings none \
81
+ --steps "${INFER_STEPS}" \
82
+ --decode_rule flowmap \
83
+ --time_schedule logit_normal \
84
+ --time_logit_mean -1.5 \
85
+ --time_logit_std 0.8 \
86
+ --model_t_mode post \
87
+ --c_min 1 \
88
+ --c_max 512 \
89
+ --late_temp 1.0 \
90
+ --final_from state \
91
+ --final_decode argmax
92
+ python - "$config" "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$CURVE_CSV" <<'PY'
93
+ import json
94
+ import sys
95
+ from pathlib import Path
96
+ config = sys.argv[1]
97
+ out = Path(sys.argv[2])
98
+ n = int(sys.argv[3])
99
+ global_batch = int(sys.argv[4])
100
+ max_len = int(sys.argv[5])
101
+ curve = Path(sys.argv[6])
102
+ row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
103
+ views = int(row["ckpt_step"]) * global_batch
104
+ tokens = views * max_len
105
+ print(
106
+ "RESULT "
107
+ f"config={config} ckpt_step={row['ckpt_step']} views={views} "
108
+ f"token_acc={row['token_acc_mean']:.4f} exact={row['exact_count']}/{n} "
109
+ f"exact_refs={row['exact_ref_count']} hits={row['exact_ref_hits']}",
110
+ flush=True,
111
+ )
112
+ with curve.open("a", encoding="utf-8") as f:
113
+ f.write(
114
+ f"{config},{row['ckpt_step']},{views},{tokens},{row['token_acc_mean']},"
115
+ f"{row['exact_count']},{row['exact_ref_count']},\"{row['exact_ref_hits']}\"\n"
116
+ )
117
+ PY
118
+ }
119
+
120
+ configs=(
121
+ wrongfloor0p3
122
+ wrongfloor0p5
123
+ wrongfloor0p7
124
+ )
125
+
126
+ echo "[wrong-floor] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"
127
+ round_idx=0
128
+ while :; do
129
+ round_idx=$((round_idx + 1))
130
+ active=0
131
+ echo "[wrong-floor] round=${round_idx} $(date)" | tee -a "${DRIVER_LOG}"
132
+ for config in "${configs[@]}"; do
133
+ floor="${config#wrongfloor}"
134
+ floor="${floor//p/.}"
135
+ run_name="train8_wrongfloor_len${MAX_LEN}_${config}_${GROUP_STAMP}"
136
+ step_now="$(latest_step "${run_name}")"
137
+ if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
138
+ echo "[wrong-floor] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
139
+ continue
140
+ fi
141
+ active=1
142
+ target_step=$((step_now + STEP_CHUNK))
143
+ if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
144
+ target_step="${MAX_TOTAL_STEPS}"
145
+ fi
146
+ resume_path=""
147
+ if [[ -f "runs/${run_name}/latest.pt" ]]; then
148
+ resume_path="runs/${run_name}/latest.pt"
149
+ fi
150
+ echo "[wrong-floor] train config=${config} floor=${floor} from=${step_now} to=${target_step}" | tee -a "${DRIVER_LOG}"
151
+ CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
152
+ NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
153
+ MASTER_PORT="$(free_port)" \
154
+ OWT_CHUNK_CACHE_DIR="${cache}" \
155
+ OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
156
+ MAX_LEN="${MAX_LEN}" \
157
+ VOCAB_SIZE_OVERRIDE="${vocab_size}" \
158
+ D_MODEL="${D_MODEL:-192}" \
159
+ COND_DIM="${COND_DIM:-64}" \
160
+ N_LAYERS="${N_LAYERS:-3}" \
161
+ N_HEADS="${N_HEADS:-3}" \
162
+ DIM_FF="${DIM_FF:-768}" \
163
+ TOTAL_STEPS="${target_step}" \
164
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
165
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
166
+ NUM_WORKERS="${NUM_WORKERS:-0}" \
167
+ LOG_EVERY="${LOG_EVERY:-100}" \
168
+ SAVE_EVERY="${STEP_CHUNK}" \
169
+ LATEST_EVERY="${STEP_CHUNK}" \
170
+ WARMUP_STEPS="${WARMUP_STEPS:-10}" \
171
+ LEARNING_RATE="${LEARNING_RATE:-0.002}" \
172
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" \
173
+ MUON_IMPL="${MUON_IMPL:-legacy}" \
174
+ OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}" \
175
+ TARGET_LOSS=hard_ce \
176
+ MIN_MASK_RATIO=1.0 \
177
+ MAX_MASK_RATIO=1.0 \
178
+ MASK_MIXTURE_LOWK_PROB=0.0 \
179
+ MASK_MIXTURE_ALL_PROB=1.0 \
180
+ LOWK_CLEAN_TOKENS=0 \
181
+ CLEAN_STATE_MODE=onehot \
182
+ ROLLOUT_TRAIN_PROB=0.0 \
183
+ CATEGORICAL_WRONG_PROB_FLOOR="${floor}" \
184
+ RUN_NAME="${run_name}" \
185
+ RESUME_PATH="${resume_path}" \
186
+ bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
187
+ echo "[wrong-floor] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
188
+ eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"
189
+ done
190
+ if [[ "${active}" -eq 0 ]]; then
191
+ echo "[wrong-floor] all capped $(date)" | tee -a "${DRIVER_LOG}"
192
+ break
193
+ fi
194
+ done
LTA_openwebtext_dualt/scripts/watch_infer_owt_classic_fullvocab_len1024_lr2e4_gbs2048_latest_every1k_t1p45.sh ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
7
+ export TOKENIZERS_PARALLELISM=false
8
+ export PYTHONUNBUFFERED=1
9
+
10
+ # Watch the 16-GPU OWT classic full-vocab len1024/lr2e-4/GBS2048 run.
11
+ # The training command saves step_*.pt every 10k but latest.pt every 1k, so this
12
+ # watcher snapshots stable latest.pt at each new 1k step before running infer.
13
+
14
+ RUN_GLOB="${RUN_GLOB:-runs/lta_owt_classic_fullvocab_bert_c1024_len1024_lr2e4_gbs2048_2node8gpu_1m_save10k_*}"
15
+ RUN_DIR="${RUN_DIR:-}"
16
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
17
+ SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"
18
+
19
+ CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
20
+ N_SAMPLES="${N_SAMPLES:-1024}"
21
+ STEPS="${STEPS:-128}"
22
+ CMAX="${CMAX:-1024}"
23
+ TEMP="${TEMP:-1.45}"
24
+ MAX_LEN="${MAX_LEN:-1024}"
25
+ DECODE_BATCH="${DECODE_BATCH:-1}"
26
+ SCORE_BATCH="${SCORE_BATCH:-1}"
27
+ SCORE_MAX_LENGTH="${SCORE_MAX_LENGTH:-1024}"
28
+ SLEEP_SECONDS="${SLEEP_SECONDS:-60}"
29
+ STEP_INTERVAL="${STEP_INTERVAL:-1000}"
30
+ DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
31
+
32
+ TEMP_TAG="${TEMP//./p}"
33
+ LOG_DIR="${LOG_DIR:-logs/owt_classic_fullvocab_len1024_lr2e4_gbs2048_infer_watch}"
34
+ OUT_ROOT="${OUT_ROOT:-docs/lta_samples/metrics_${DATE_TAG}/owt_classic_fullvocab_len1024_lr2e4_gbs2048_latest_every1k_normal_steps_state_t${TEMP_TAG}_c${CMAX}_n${N_SAMPLES}}"
35
+
36
+ mkdir -p "${LOG_DIR}" "${OUT_ROOT}"
37
+
38
+ find_run_dir() {
39
+ if [[ -n "${RUN_DIR}" ]]; then
40
+ if [[ -d "${RUN_DIR}" ]]; then
41
+ printf '%s\n' "${RUN_DIR}"
42
+ return 0
43
+ fi
44
+ return 1
45
+ fi
46
+ shopt -s nullglob
47
+ local matches=( ${RUN_GLOB} )
48
+ shopt -u nullglob
49
+ if (( ${#matches[@]} == 0 )); then
50
+ return 1
51
+ fi
52
+ ls -td "${matches[@]}" 2>/dev/null | head -1
53
+ }
54
+
55
+ wait_for_stable_file() {
56
+ local path="$1"
57
+ local stat_a stat_b
58
+ stat_a="$(stat -c '%s:%Y' "${path}" 2>/dev/null || echo missing)"
59
+ sleep 20
60
+ stat_b="$(stat -c '%s:%Y' "${path}" 2>/dev/null || echo changed)"
61
+ [[ "${stat_a}" == "${stat_b}" && "${stat_a}" != "missing" ]]
62
+ }
63
+
64
+ read_ckpt_step() {
65
+ local ckpt="$1"
66
+ python - "$ckpt" <<'PY'
67
+ import sys
68
+ import torch
69
+ ckpt = torch.load(sys.argv[1], map_location="cpu", weights_only=False)
70
+ step = ckpt.get("step")
71
+ if step is None:
72
+ raise SystemExit("checkpoint has no step")
73
+ print(int(step))
74
+ PY
75
+ }
76
+
77
+ echo "[watch-owt-len1024-lr2e4] run_glob=${RUN_GLOB}"
78
+ echo "[watch-owt-len1024-lr2e4] explicit_run_dir=${RUN_DIR:-<auto>}"
79
+ echo "[watch-owt-len1024-lr2e4] out_root=${OUT_ROOT}"
80
+ echo "[watch-owt-len1024-lr2e4] decode=normal_steps_sweep steps=${STEPS} cmax=${CMAX} temp=${TEMP} final_from=state n=${N_SAMPLES} max_len=${MAX_LEN}"
81
+ echo "[watch-owt-len1024-lr2e4] source=latest.pt snapshot_each=${STEP_INTERVAL} decode_batch=${DECODE_BATCH} score_batch=${SCORE_BATCH}"
82
+
83
+ while true; do
84
+ current_run_dir="$(find_run_dir || true)"
85
+ if [[ -z "${current_run_dir}" ]]; then
86
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) waiting for matching run: ${RUN_GLOB}"
87
+ sleep "${SLEEP_SECONDS}"
88
+ continue
89
+ fi
90
+
91
+ run_stem="$(basename "${current_run_dir}")"
92
+ latest_ckpt="${current_run_dir}/latest.pt"
93
+ out_base="${OUT_ROOT}/${run_stem}"
94
+ processed_file="${LOG_DIR}/processed_${run_stem}_steps${STEPS}_c${CMAX}_t${TEMP_TAG}_n${N_SAMPLES}.txt"
95
+ snapshot_dir="${current_run_dir}/latest_snapshots_1k"
96
+ mkdir -p "${out_base}" "${LOG_DIR}" "${snapshot_dir}"
97
+ touch "${processed_file}"
98
+
99
+ if [[ ! -f "${latest_ckpt}" ]]; then
100
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) run=${run_stem} no latest.pt yet"
101
+ sleep "${SLEEP_SECONDS}"
102
+ continue
103
+ fi
104
+ if ! wait_for_stable_file "${latest_ckpt}"; then
105
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) latest.pt not stable yet"
106
+ sleep "${SLEEP_SECONDS}"
107
+ continue
108
+ fi
109
+
110
+ step_num="$(read_ckpt_step "${latest_ckpt}")"
111
+ if (( step_num <= 0 || step_num % STEP_INTERVAL != 0 )); then
112
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) latest step=${step_num}; waiting for multiple of ${STEP_INTERVAL}"
113
+ sleep "${SLEEP_SECONDS}"
114
+ continue
115
+ fi
116
+
117
+ step="$(printf '%07d' "${step_num}")"
118
+ snapshot="${snapshot_dir}/step_${step}.pt"
119
+ processed_key="${current_run_dir}:step_${step}"
120
+ if grep -Fxq "${processed_key}" "${processed_file}"; then
121
+ sleep "${SLEEP_SECONDS}"
122
+ continue
123
+ fi
124
+
125
+ if [[ ! -f "${snapshot}" ]]; then
126
+ tmp_snapshot="${snapshot}.tmp.$$"
127
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) snapshot latest step_${step} -> ${snapshot}"
128
+ cp --reflink=auto "${latest_ckpt}" "${tmp_snapshot}" 2>/dev/null || cp "${latest_ckpt}" "${tmp_snapshot}"
129
+ mv "${tmp_snapshot}" "${snapshot}"
130
+ fi
131
+
132
+ out_dir="${out_base}/step_${step}"
133
+ log_file="${LOG_DIR}/infer_${run_stem}_step_${step}_t${TEMP_TAG}.log"
134
+ mkdir -p "${out_dir}"
135
+
136
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) infer ${snapshot} -> ${out_dir}" | tee -a "${log_file}"
137
+ CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" python scripts/eval_owt_normal_steps_sweep_20260515.py \
138
+ --checkpoint "${snapshot}" \
139
+ --tokenizer_path "${TOKENIZER_PATH}" \
140
+ --scorer "${SCORER}" \
141
+ --out_dir "${out_dir}" \
142
+ --steps_list "${STEPS}" \
143
+ --cmax_list "${CMAX}" \
144
+ --endpoint_temps "${TEMP}" \
145
+ --n_samples "${N_SAMPLES}" \
146
+ --max_len "${MAX_LEN}" \
147
+ --decode_batch "${DECODE_BATCH}" \
148
+ --score_batch "${SCORE_BATCH}" \
149
+ --score_max_length "${SCORE_MAX_LENGTH}" \
150
+ --detokenizer none \
151
+ --seed 20260521 \
152
+ --save_samples 16 \
153
+ 2>&1 | tee -a "${log_file}"
154
+
155
+ echo "${processed_key}" >> "${processed_file}"
156
+ echo "[watch-owt-len1024-lr2e4] $(date +%F_%T) done step_${step}" | tee -a "${log_file}"
157
+ sleep "${SLEEP_SECONDS}"
158
+ done