JinghuiLuAstronaut committed on
Commit
b6cb07d
·
verified ·
1 Parent(s): 80373af

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/nohup.log +0 -0
  2. LTA_openwebtext_dualt/scripts/eval_ar_8gpu_latest_temp_sweep_20260506.py +28 -0
  3. LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_20260508.py +51 -0
  4. LTA_openwebtext_dualt/scripts/eval_lm1b_c1024_fullycoupled_8gpu_1m_sde.sh +39 -0
  5. LTA_openwebtext_dualt/scripts/extract_mauve_features_lm1b.py +110 -0
  6. LTA_openwebtext_dualt/scripts/launch_lta_lm1b_fullycoupled_t5_logitnormal_8gpu.sh +6 -0
  7. LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh +250 -0
  8. LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_elftsampling_mask1_wd0p1_fp32_8gpu.sh +202 -0
  9. LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_uniform_mask1_swiglu_wd0p1_fp32_8gpu.sh +204 -0
  10. LTA_openwebtext_dualt/scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh +271 -0
  11. LTA_openwebtext_dualt/scripts/launch_lta_owt_selfcond_p05_4gpu.sh +65 -0
  12. LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_blockar_twostream_8gpu.sh +120 -0
  13. LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_p50_randk0_4_8gpu.sh +60 -0
  14. LTA_openwebtext_dualt/scripts/make_compact_gpt2_bpe_tokenizer.py +160 -0
  15. LTA_openwebtext_dualt/scripts/run_lta_lm1b_linear_simplex_len128_4gpu_save1k_with_watch.sh +217 -0
  16. LTA_openwebtext_dualt/scripts/run_lta_owt_t5elf_absrope_time4_dirichlet_len1025_C1_to_1024_8gpu_1m_mask1_sameT_prebos_save10k.sh +74 -0
  17. LTA_openwebtext_dualt/scripts/run_train8_ctx1024_rollin_sweep_4gpu.sh +235 -0
  18. LTA_openwebtext_dualt/scripts/run_train8_len256_headwd_sweep_bs512_until_exact_4gpu.sh +287 -0
  19. LTA_openwebtext_dualt/scripts/run_train8_noise_geometry_pilots_4gpu.sh +289 -0
  20. LTA_openwebtext_dualt/scripts/sweep_categorical_c1024_diffusion_finalsample_20260506.py +48 -0
LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/nohup.log ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/scripts/eval_ar_8gpu_latest_temp_sweep_20260506.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import importlib.util
5
+ import sys
6
+ from pathlib import Path
7
+
8
+
9
# Location of the sibling evaluation script whose machinery (EvalTarget,
# TARGETS, main) this wrapper reuses.
BASE_PATH = Path(__file__).parent / "eval_lm1b_200k_methods_genppl_20260506.py"

# Import the sibling script by file path; it is a standalone script rather
# than an importable package module.
spec = importlib.util.spec_from_file_location(
    "eval_lm1b_200k_methods_genppl_20260506",
    BASE_PATH,
)
_loader = None if spec is None else spec.loader
if _loader is None:
    raise RuntimeError(f"Could not load {BASE_PATH}")
base = importlib.util.module_from_spec(spec)
# Register the module under its spec name before executing it.
sys.modules[spec.name] = base
_loader.exec_module(base)


# Replace the base script's target list with the single AR checkpoint
# evaluated by this wrapper.
_AR_LATEST_CHECKPOINT = (
    "runs/ar_lm1b_flmpack_bert_small_len128_gbs512_8gpu_1m_scratch_20260505/latest.pt"
)
base.TARGETS = [
    base.EvalTarget("ar_8gpu_scratch_latest", "ar", _AR_LATEST_CHECKPOINT),
]


if __name__ == "__main__":
    base.main()
LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_20260508.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import importlib.util
5
+ import sys
6
+ from pathlib import Path
7
+
8
+
9
# Sibling decode-sweep script that provides DecodeConfig, default_configs and
# main(); it is imported by file path because it is a standalone script.
BASE = Path(__file__).parent / "eval_c1024_decode_sweep_20260507.py"
spec = importlib.util.spec_from_file_location(
    "eval_c1024_decode_sweep_20260507",
    BASE,
)
_loader = None if spec is None else spec.loader
if _loader is None:
    raise RuntimeError(f"cannot import {BASE}")
base = importlib.util.module_from_spec(spec)
# Register the module under its spec name before executing it.
sys.modules[spec.name] = base
_loader.exec_module(base)
16
+
17
+
18
def key_configs() -> list[base.DecodeConfig]:
    """Return the three decode configurations evaluated by this script.

    The configurations are identical except for ``concentration_max``
    (16, 64 and 1024), which is also encoded in each configuration's name.
    """
    return [
        base.DecodeConfig(
            f"match_post_sem1_blend_c{concentration}_t1p3",
            "post",
            1.0,
            1.0,
            "blend",
            endpoint_temp=1.3,
            concentration_max=float(concentration),
        )
        for concentration in (16, 64, 1024)
    ]


# Make the base sweep use the reduced configuration list, then run it.
base.default_configs = key_configs
base.main()
LTA_openwebtext_dualt/scripts/eval_lm1b_c1024_fullycoupled_8gpu_1m_sde.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# SDE (Dirichlet resample) GenPPL eval for:
#   lta_lm1b_dirichlet_categorical_fullvocab_c1024_fullycoupled_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0
#
# Every setting below except ROOT can be overridden from the environment,
# e.g. STEPS=256 N_SAMPLES=32 bash <this script>.
set -euo pipefail

ROOT="/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt"
cd "${ROOT}"
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false

# Inputs: run directory, checkpoint, tokenizer and scoring model.
RUN_DIR="${RUN_DIR:-runs/lta_lm1b_dirichlet_categorical_fullvocab_c1024_fullycoupled_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0}"
CHECKPOINT="${CHECKPOINT:-${RUN_DIR}/latest.pt}"
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"

# Sampling / scoring sizes forwarded to the Python entry point.
STEPS="${STEPS:-1024}"
N_SAMPLES="${N_SAMPLES:-128}"
MAX_LEN="${MAX_LEN:-128}"
BATCH_SIZE="${BATCH_SIZE:-16}"
SCORE_BATCH="${SCORE_BATCH:-8}"

# The default output directory encodes the step count and sample count.
OUT_DIR="${OUT_DIR:-${RUN_DIR}/eval_sde_steps${STEPS}_n${N_SAMPLES}}"

# Defaults to GPU 0 when the caller has not chosen devices.
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"

echo "[eval-sde] checkpoint=${CHECKPOINT}"
echo "[eval-sde] steps=${STEPS} n_samples=${N_SAMPLES} max_len=${MAX_LEN}"
echo "[eval-sde] out_dir=${OUT_DIR}"

# Fail fast with a clear message (exit code 2) instead of letting the Python
# process start up only to die on a missing input path.
if [[ ! -e "${CHECKPOINT}" ]]; then
  echo "Missing checkpoint: ${CHECKPOINT}" >&2
  exit 2
fi
if [[ ! -e "${TOKENIZER_PATH}" ]]; then
  echo "Missing tokenizer: ${TOKENIZER_PATH}" >&2
  exit 2
fi

CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" python scripts/eval_lm1b_c1024_fullycoupled_sde_genppl.py \
  --checkpoint "${CHECKPOINT}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --scorer "${SCORER}" \
  --out_dir "${OUT_DIR}" \
  --steps "${STEPS}" \
  --n_samples "${N_SAMPLES}" \
  --max_len "${MAX_LEN}" \
  --batch_size "${BATCH_SIZE}" \
  --score_batch "${SCORE_BATCH}"
LTA_openwebtext_dualt/scripts/extract_mauve_features_lm1b.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+
12
+ IN_PATH = Path("runs/decode_lab/mauve_export_lm1b_latest_1024.jsonl")
13
+ OUT_PATH = Path("runs/decode_lab/mauve_features_lm1b_latest_1024_gpt2large.npz")
14
+ MODEL_PATH = "/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard"
15
+ MAX_LEN = 256
16
+ BATCH_SIZE = 16
17
+
18
+
19
def load_texts() -> tuple[list[str], list[str], dict[str, list[str]], dict[str, list[str]]]:
    """Read the JSONL export at ``IN_PATH`` and split it into text groups.

    Records with ``type == "reference"`` are collected into two parallel
    lists (raw and stripped text). Records with ``type == "generated"`` are
    grouped under a key built from their ``setting`` entry,
    ``t<endpoint_temp with 2 decimals>_<final_from>``. Records of any other
    type are ignored, and blank lines (for example a trailing empty line) are
    skipped instead of crashing the JSON parser.

    Returns:
        ``(refs_raw, refs_stripped, gen_raw, gen_stripped)`` where the two
        dictionaries map each setting key to its texts in file order.
    """
    refs_raw: list[str] = []
    refs_stripped: list[str] = []
    gen_raw: dict[str, list[str]] = {}
    gen_stripped: dict[str, list[str]] = {}
    with IN_PATH.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines; json.loads would raise on them.
                continue
            obj = json.loads(line)
            typ = obj.get("type")
            if typ == "reference":
                refs_raw.append(obj["raw_text"])
                refs_stripped.append(obj["stripped_text"])
            elif typ == "generated":
                setting = obj["setting"]
                key = f"t{float(setting['endpoint_temp']):.2f}_{setting['final_from']}"
                gen_raw.setdefault(key, []).append(obj["raw_text"])
                gen_stripped.setdefault(key, []).append(obj["stripped_text"])
    return refs_raw, refs_stripped, gen_raw, gen_stripped
37
+
38
+
39
@torch.no_grad()
def featurize(
    texts: list[str],
    name: str,
    tokenizer: "AutoTokenizer",
    model: "AutoModel",
    device: torch.device,
) -> np.ndarray:
    """Embed each text as the final-layer hidden state of its last real token.

    Texts are processed in batches of ``BATCH_SIZE``, truncated to ``MAX_LEN``
    tokens and padded within each batch. Progress is printed every tenth
    batch, and a summary line with the result shape and elapsed time is
    printed at the end.

    Args:
        texts: Strings to embed.
        name: Label used only in the progress and summary output.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors and supporting ``.to(device)``.
        model: Model called with ``output_hidden_states=True``.
        device: Device the encoded batch is moved to.

    Returns:
        A float NumPy array with one row per input text. For an empty
        ``texts`` a zero-row array is returned instead of raising.
    """
    t0 = time.time()
    if not texts:
        # np.concatenate rejects an empty list, so return a zero-row array.
        # The width comes from model.config.hidden_size when available
        # (0 otherwise).
        width = getattr(getattr(model, "config", None), "hidden_size", 0)
        empty = np.zeros((0, int(width or 0)), dtype=np.float32)
        print(f"{name} {empty.shape} time={time.time() - t0:.1f}s", flush=True)
        return empty

    chunks: list[np.ndarray] = []
    for start in range(0, len(texts), BATCH_SIZE):
        batch = texts[start : start + BATCH_SIZE]
        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
            return_attention_mask=True,
        ).to(device)
        out = model(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = out.hidden_states[-1]
        # Index of the last attended position in each row. Weighting the mask
        # by position and taking argmax works for right- and left-padded
        # batches alike, whereas "mask.sum() - 1" is only correct for right
        # padding. A row with no attended token falls back to position 0.
        mask = enc["attention_mask"]
        positions = torch.arange(mask.size(1), device=mask.device)
        last_idx = (mask * positions).argmax(dim=1)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        feat = hidden[rows, last_idx]
        chunks.append(feat.float().cpu().numpy())
        if (start // BATCH_SIZE) % 10 == 0:
            print(f"{name} {start + len(batch)}/{len(texts)}", flush=True)
    arr = np.concatenate(chunks, axis=0)
    print(f"{name} {arr.shape} time={time.time() - t0:.1f}s", flush=True)
    return arr
74
+
75
+
76
def main() -> None:
    """Featurize reference and generated texts and save them to ``OUT_PATH``.

    Loads the texts, loads the tokenizer and model from ``MODEL_PATH`` (using
    the EOS token as padding when the tokenizer defines none), embeds every
    text group and writes all arrays into one compressed ``.npz`` archive.
    Archive keys are ``ref_raw``, ``ref_stripped`` and, for each generation
    setting key, ``gen_<key>_raw`` and ``gen_<key>_stripped``.
    """
    refs_raw, refs_stripped, gen_raw, gen_stripped = load_texts()
    gen_counts = {key: len(value) for key, value in gen_raw.items()}
    print("loaded", len(refs_raw), gen_counts, flush=True)

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token_id is None:
        # Batched padding needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModel.from_pretrained(MODEL_PATH, pad_token_id=tokenizer.eos_token_id)
    model = model.to(device).eval()

    # (archive key, texts) pairs in processing order: references first, then
    # each generation setting in sorted key order.
    jobs: list[tuple[str, list[str]]] = [
        ("ref_raw", refs_raw),
        ("ref_stripped", refs_stripped),
    ]
    for key in sorted(gen_raw):
        jobs.append((f"gen_{key}_raw", gen_raw[key]))
        jobs.append((f"gen_{key}_stripped", gen_stripped[key]))

    arrays: dict[str, np.ndarray] = {}
    for label, texts in jobs:
        arrays[label] = featurize(texts, label, tokenizer, model, device)

    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(OUT_PATH, **arrays)
    print(f"DONE {OUT_PATH}", flush=True)


if __name__ == "__main__":
    main()
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_fullycoupled_t5_logitnormal_8gpu.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Thin alias: changes into the workspace and hands over to the "elfaligned"
# launcher, forwarding all command-line arguments unchanged.
set -euo pipefail

WORKSPACE_ROOT=/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
cd "${WORKSPACE_ROOT}"

# exec replaces this shell, so the launcher's exit status is returned as-is.
exec bash scripts/launch_lta_lm1b_elfaligned_t5_logitnormal_8gpu.sh "$@"
LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
8
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
13
+
14
+ # Compact GPT2-BPE v2048, rebuilt as tokenizer-stream chunks:
15
+ # every example is length 1024 except the final tail record in each part.
16
+ COMPACT_VARIANT_LABEL="${COMPACT_VARIANT_LABEL:-compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32}"
17
+ VOCAB_SIZE="${VOCAB_SIZE:-2048}"
18
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-compact-gpt2bpe-v2048-stream1024-train-minus-100k}"
19
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/lta_tokenizers/owt_compact_gpt2bpe_v2048/tokenizer.json}"
20
+
21
+ NNODES="${NNODES:-1}"
22
+ NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
23
+ NODE_RANK="${NODE_RANK:-0}"
24
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
25
+ MASTER_PORT="${MASTER_PORT:-32249}"
26
+
27
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
28
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
29
+ TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
30
+ WARMUP_STEPS="${WARMUP_STEPS:-2000}"
31
+ NUM_WORKERS="${NUM_WORKERS:-8}"
32
+ DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
33
+ LOG_EVERY="${LOG_EVERY:-50}"
34
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
35
+ EVAL_EVERY="${EVAL_EVERY:-0}"
36
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
37
+ ALLOW_TF32="${ALLOW_TF32:-1}"
38
+ DRY_RUN="${DRY_RUN:-0}"
39
+ RESUME_PATH="${RESUME_PATH:-}"
40
+
41
+ # Fully-coupled fp32 baseline recipe: AdamW, nanoGPT decay groups, RMSNorm,
42
+ # no output bias, hard CE, Dirichlet bridge, full-vocab wrong token.
43
+ LR="${LR:-6e-4}"
44
+ MIN_LR="${MIN_LR:-6e-5}"
45
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
46
+ OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
47
+ ADAM_BETA1="${ADAM_BETA1:-0.9}"
48
+ ADAM_BETA2="${ADAM_BETA2:-0.95}"
49
+ ADAM_EPS="${ADAM_EPS:-1e-8}"
50
+ GRAD_CLIP="${GRAD_CLIP:-1.0}"
51
+ EMA_DECAY="${EMA_DECAY:-0.0}"
52
+ EMA_START_STEP="${EMA_START_STEP:-0}"
53
+ LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
54
+ LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
55
+ LOSS_T_DROP_BELOW="${LOSS_T_DROP_BELOW:-0.2}"
56
+
57
+ T_SAMPLING_MODE="${T_SAMPLING_MODE:-logit_normal}"
58
+ T_SAMPLING_POWER="${T_SAMPLING_POWER:-1.0}"
59
+ T_SAMPLING_EPS="${T_SAMPLING_EPS:-1e-4}"
60
+ T_SAMPLING_LOGIT_MEAN="${T_SAMPLING_LOGIT_MEAN:--1.5}"
61
+ T_SAMPLING_LOGIT_STD="${T_SAMPLING_LOGIT_STD:-0.8}"
62
+ MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
63
+ MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
64
+
65
# Turn a numeric setting into a name-safe label by replacing every "-" with
# "m" and every "." with "p" (e.g. "-1.5" -> "m1p5"). Prints the result
# without a trailing newline.
sanitize_label() {
  local label="$1"
  label="${label//-/m}"
  label="${label//./p}"
  printf "%s" "${label}"
}
68
+
69
+ T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN}")"
70
+ T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD}")"
71
+ MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO}")"
72
+ MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO}")"
73
+
74
+ RUN_NAME="${RUN_NAME:-lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_logitnormal_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
75
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
76
+ LOG_DIR="${LOG_DIR:-logs/compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu}"
77
+ LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
78
+
79
# Resuming implies writing into an existing save directory, so allow it, and
# require that the checkpoint to resume from is a regular file.
if [[ -n "${RESUME_PATH}" ]]; then
  ALLOW_EXISTING_SAVE_DIR=1
  [[ -f "${RESUME_PATH}" ]] || {
    echo "Missing resume checkpoint: ${RESUME_PATH}" >&2
    exit 2
  }
fi

# The tokenizer file must exist before anything is launched.
[[ -f "${TOKENIZER_PATH}" ]] || {
  echo "Missing tokenizer: ${TOKENIZER_PATH}" >&2
  exit 2
}

# The pre-tokenized dataset directory must exist; point at the build script
# when it does not.
[[ -d "${DATA_PATH}" ]] || {
  echo "Missing tokenized dataset: ${DATA_PATH}" >&2
  echo "Build it with: bash scripts/build_lta_owt_compact_gpt2bpe_stream1024_train_minus_100k_np8.sh" >&2
  exit 2
}
96
+
97
+ NUM_EXAMPLES=$(python - <<PY
98
+ import json
99
+ from pathlib import Path
100
+ from datasets import Sequence, load_from_disk
101
+ from datasets.features import features as hf_features
102
+ hf_features._FEATURE_TYPES.setdefault("List", Sequence)
103
+ root = Path("${DATA_PATH}")
104
+ for name in ("elf_multi_part_meta.json", "elf_build_meta.json"):
105
+ path = root / name
106
+ if path.exists():
107
+ print(int(json.loads(path.read_text()).get("num_examples", 0)))
108
+ raise SystemExit
109
+ parts_root = root / "parts"
110
+ if parts_root.is_dir():
111
+ print(sum(len(load_from_disk(str(p))) for p in sorted(parts_root.iterdir()) if p.is_dir()))
112
+ else:
113
+ print(len(load_from_disk(str(root))))
114
+ PY
115
+ )
116
+ WORLD_SIZE=$(( NNODES * NPROC_PER_NODE ))
117
+ GRAD_ACCUM=$(( (GLOBAL_BATCH_SIZE + PER_GPU_BATCH_SIZE * WORLD_SIZE - 1) / (PER_GPU_BATCH_SIZE * WORLD_SIZE) ))
118
+ SAMPLES_PER_RANK=$(( (NUM_EXAMPLES + WORLD_SIZE - 1) / WORLD_SIZE ))
119
+ LOADER_BATCHES_PER_RANK=$(( SAMPLES_PER_RANK / PER_GPU_BATCH_SIZE ))
120
+ STEPS_PER_EPOCH=$(( (LOADER_BATCHES_PER_RANK + GRAD_ACCUM - 1) / GRAD_ACCUM ))
121
+ SAVE_EVERY="${SAVE_EVERY:-${STEPS_PER_EPOCH}}"
122
+
123
+ if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
124
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
125
+ echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
126
+ exit 2
127
+ fi
128
+
129
+ mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
130
+
131
+ TF32_FLAG="--allow_tf32"
132
+ TF32_LABEL="true"
133
+ if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
134
+ TF32_FLAG="--no-allow_tf32"
135
+ TF32_LABEL="false"
136
+ fi
137
+
138
+ echo "[launch] method=${COMPACT_VARIANT_LABEL} host=$(hostname) time=$(date -Iseconds)"
139
+ echo "[launch] run_name=${RUN_NAME}"
140
+ echo "[launch] save_dir=${SAVE_DIR}"
141
+ echo "[launch] log_file=${LOG_FILE}"
142
+ echo "[launch] data_path=${DATA_PATH}"
143
+ echo "[launch] tokenizer=${TOKENIZER_PATH}"
144
+ echo "[launch] examples=${NUM_EXAMPLES} total_steps=${TOTAL_STEPS} world_size=${WORLD_SIZE} grad_accum=${GRAD_ACCUM} steps_per_epoch=${STEPS_PER_EPOCH} save_every=${SAVE_EVERY}"
145
+ echo "[launch] optimizer=adamw lr=${LR} min_lr=${MIN_LR} wd=${WEIGHT_DECAY} output_wd=${OUTPUT_WEIGHT_DECAY} param_groups=nanogpt ema=${EMA_DECAY}"
146
+ echo "[launch] fp32=true bf16=false tf32=${TF32_LABEL} norm_type=rmsnorm output_bias=false batch=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE}"
147
+ echo "[launch] target_loss=hard_ce t_sampling_mode=${T_SAMPLING_MODE} t_sampling_logit_mean=${T_SAMPLING_LOGIT_MEAN} t_sampling_logit_std=${T_SAMPLING_LOGIT_STD} mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
148
+ if [[ -n "${RESUME_PATH}" ]]; then
149
+ echo "[launch] resume_path=${RESUME_PATH}"
150
+ fi
151
+
152
+ if [[ "${DRY_RUN}" == "1" || "${DRY_RUN}" == "true" || "${DRY_RUN}" == "True" ]]; then
153
+ echo "[launch] DRY_RUN=1, validated tokenizer/data/schedule; skipping torchrun."
154
+ exit 0
155
+ fi
156
+
157
# Optional resume flag: stays empty unless RESUME_PATH was provided.
RESUME_ARGS=()
if [[ -n "${RESUME_PATH}" ]]; then
  RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
fi

# Launch distributed training. Output (stdout and stderr) is appended to
# LOG_FILE while still being shown on the terminal.
#
# RESUME_ARGS is expanded as ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} because this
# script runs under `set -u`: bash releases before 4.4 treat a plain
# "${RESUME_ARGS[@]}" on an empty array as an unbound variable and abort.
# The guarded form expands to nothing when the array is empty and to the
# properly quoted elements otherwise.
python -m torch.distributed.run \
  --nnodes="${NNODES}" \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank="${NODE_RANK}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py \
  ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} \
  --data_path "${DATA_PATH}" \
  --tokenized_hf \
  --tokenized_pad_token pad \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --save_dir "${SAVE_DIR}" \
  --max_len 1024 \
  --batch_size "${PER_GPU_BATCH_SIZE}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
  --epochs 0 \
  --total_steps "${TOTAL_STEPS}" \
  --warmup_steps "${WARMUP_STEPS}" \
  --log_every "${LOG_EVERY}" \
  --eval_every "${EVAL_EVERY}" \
  --save_every "${SAVE_EVERY}" \
  --latest_every "${LATEST_EVERY}" \
  --optimizer adamw \
  --lr "${LR}" \
  --lr_schedule cosine \
  --min_lr "${MIN_LR}" \
  --weight_decay "${WEIGHT_DECAY}" \
  --output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
  --adamw_param_groups nanogpt \
  --adam_beta1 "${ADAM_BETA1}" \
  --adam_beta2 "${ADAM_BETA2}" \
  --adam_eps "${ADAM_EPS}" \
  --ema_decay "${EMA_DECAY}" \
  --ema_start_step "${EMA_START_STEP}" \
  --grad_clip "${GRAD_CLIP}" \
  --seed 123 \
  --d_model 768 \
  --cond_dim 128 \
  --n_layers 12 \
  --n_heads 12 \
  --dim_ff 3072 \
  --dropout 0.0 \
  --no-output_bias \
  --norm_type rmsnorm \
  --model_type ddit \
  --state_format prob \
  --bridge dirichlet \
  --target_loss hard_ce \
  --loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
  --loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
  --loss_t_drop_below "${LOSS_T_DROP_BELOW}" \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --t_sampling_mode "${T_SAMPLING_MODE}" \
  --t_sampling_power "${T_SAMPLING_POWER}" \
  --t_sampling_logit_mean "${T_SAMPLING_LOGIT_MEAN}" \
  --t_sampling_logit_std "${T_SAMPLING_LOGIT_STD}" \
  --t_sampling_eps "${T_SAMPLING_EPS}" \
  --dual_t \
  --corrupt_t_mode same \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio "${MIN_MASK_RATIO}" \
  --max_mask_ratio "${MAX_MASK_RATIO}" \
  --wrong_token_replace_prob 1.0 \
  --wrong_token_schedule linear_t \
  --wrong_token_exp_k 1.0 \
  --dirichlet_concentration_min 1.0 \
  --dirichlet_concentration_max 1024 \
  --dirichlet_endpoint_mode categorical_dual_t \
  --dirichlet_semantic_t_mode same \
  --dirichlet_semantic_t_value 0.0 \
  --categorical_wrong_from_full_vocab \
  --simplex_bridge_sampler dirichlet \
  --eps 1e-8 \
  --infer_steps 1024 \
  --decode_damping 1.0 \
  --max_gamma 1.0 \
  --decode_solver flowmap \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma -1 \
  "${TF32_FLAG}" \
  --ddp_gradient_as_bucket_view \
  2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_elftsampling_mask1_wd0p1_fp32_8gpu.sh ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
8
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
13
+
14
+ # Fully-coupled OWT baseline with FP32 params/activations and nanoGPT-style AdamW decay:
15
+ # decay: matrix / embedding params (p.dim() >= 2)
16
+ # no decay: bias / norm / 1D params
17
+ # This intentionally does NOT pass --bf16. TF32 is enabled by default so H200 uses
18
+ # Tensor Cores; set ALLOW_TF32=0 for strict FP32 debugging.
19
+
20
+ T_SAMPLING_MODE="${T_SAMPLING_MODE:-logit_normal}"
21
+ T_SAMPLING_POWER="${T_SAMPLING_POWER:-1.0}"
22
+ T_SAMPLING_EPS="${T_SAMPLING_EPS:-1e-4}"
23
+ T_SAMPLING_LOGIT_MEAN="${T_SAMPLING_LOGIT_MEAN:--1.5}"
24
+ T_SAMPLING_LOGIT_STD="${T_SAMPLING_LOGIT_STD:-0.8}"
25
+ MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
26
+ MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
27
+
28
# Turn a numeric setting into a name-safe label by replacing every "-" with
# "m" and every "." with "p" (e.g. "-1.5" -> "m1p5"). Prints the result
# without a trailing newline.
sanitize_label() {
  local label="$1"
  label="${label//-/m}"
  label="${label//./p}"
  printf "%s" "${label}"
}
31
+
32
+ T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN}")"
33
+ T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD}")"
34
+ MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO}")"
35
+ MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO}")"
36
+
37
+ RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_elftsampling_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_nanogpt_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
38
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
39
+ LOG_DIR="${LOG_DIR:-logs/fullycoupled_elftsampling_mask1_wd0p1_fp32_8gpu}"
40
+ LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
41
+
42
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
43
+ OWT_CACHE="${OWT_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
44
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
45
+
46
+ NNODES="${NNODES:-1}"
47
+ NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
48
+ NODE_RANK="${NODE_RANK:-0}"
49
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
50
+ MASTER_PORT="${MASTER_PORT:-31997}"
51
+
52
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
53
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
54
+ TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
55
+ WARMUP_STEPS="${WARMUP_STEPS:-2000}"
56
+ NUM_WORKERS="${NUM_WORKERS:-8}"
57
+ DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
58
+ LOG_EVERY="${LOG_EVERY:-50}"
59
+ SAVE_EVERY="${SAVE_EVERY:-50000}"
60
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
61
+ EVAL_EVERY="${EVAL_EVERY:-0}"
62
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
63
+ ALLOW_TF32="${ALLOW_TF32:-1}"
64
+ DRY_RUN="${DRY_RUN:-0}"
65
+
66
+ LR="${LR:-6e-4}"
67
+ MIN_LR="${MIN_LR:-6e-5}"
68
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
69
+ OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
70
+ ADAM_BETA1="${ADAM_BETA1:-0.9}"
71
+ ADAM_BETA2="${ADAM_BETA2:-0.95}"
72
+ ADAM_EPS="${ADAM_EPS:-1e-8}"
73
+ GRAD_CLIP="${GRAD_CLIP:-1.0}"
74
+ EMA_DECAY="${EMA_DECAY:-0.0}"
75
+ EMA_START_STEP="${EMA_START_STEP:-0}"
76
+ LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
77
+ LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
78
+ LOSS_T_DROP_BELOW="${LOSS_T_DROP_BELOW:-0.2}"
79
+
80
+ if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
81
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
82
+ echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
83
+ exit 2
84
+ fi
85
+
86
+ mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
87
+
88
+ TF32_FLAG="--allow_tf32"
89
+ TF32_LABEL="true"
90
+ if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
91
+ TF32_FLAG="--no-allow_tf32"
92
+ TF32_LABEL="false"
93
+ fi
94
+
95
+ echo "[launch] method=owt_fullycoupled_adamw_wd0p1_nanogpt_fp32 host=$(hostname) time=$(date -Iseconds)"
96
+ echo "[launch] run_name=${RUN_NAME}"
97
+ echo "[launch] save_dir=${SAVE_DIR}"
98
+ echo "[launch] log_file=${LOG_FILE}"
99
+ echo "[launch] data_path=${DATA_PATH}"
100
+ echo "[launch] owt_cache=${OWT_CACHE}"
101
+ echo "[launch] optimizer=adamw lr=${LR} min_lr=${MIN_LR} wd=${WEIGHT_DECAY} output_wd=${OUTPUT_WEIGHT_DECAY} param_groups=nanogpt ema=${EMA_DECAY}"
102
+ echo "[launch] fp32=true bf16=false tf32=${TF32_LABEL} norm_type=rmsnorm output_bias=false batch=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE}"
103
+ echo "[launch] loss_t_weight_mode=${LOSS_T_WEIGHT_MODE} loss_t_min_weight=${LOSS_T_MIN_WEIGHT} loss_t_drop_below=${LOSS_T_DROP_BELOW}"
104
+ echo "[launch] target_loss=hard_ce t_sampling_mode=${T_SAMPLING_MODE} t_sampling_logit_mean=${T_SAMPLING_LOGIT_MEAN} t_sampling_logit_std=${T_SAMPLING_LOGIT_STD} t_sampling_power=${T_SAMPLING_POWER} t_sampling_eps=${T_SAMPLING_EPS} mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
105
+
106
+ if [[ "${DRY_RUN}" == "1" || "${DRY_RUN}" == "true" || "${DRY_RUN}" == "True" ]]; then
107
+ echo "[launch] DRY_RUN=1, validated launch parameters; skipping torchrun."
108
+ exit 0
109
+ fi
110
+
111
+ python -m torch.distributed.run \
112
+ --nnodes="${NNODES}" \
113
+ --nproc_per_node="${NPROC_PER_NODE}" \
114
+ --node_rank="${NODE_RANK}" \
115
+ --master_addr="${MASTER_ADDR}" \
116
+ --master_port="${MASTER_PORT}" \
117
+ train.py \
118
+ --data_path "${DATA_PATH}" \
119
+ --openwebtext_split train_minus_100k \
120
+ --text_column text \
121
+ --detokenizer auto \
122
+ --tokenizer_path "${TOKENIZER_PATH}" \
123
+ --save_dir "${SAVE_DIR}" \
124
+ --wrap \
125
+ --wrap_mode stream \
126
+ --owt_cached_chunks \
127
+ --owt_chunk_cache_dir "${OWT_CACHE}" \
128
+ --max_len 1024 \
129
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
130
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
131
+ --num_workers "${NUM_WORKERS}" \
132
+ --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
133
+ --total_steps "${TOTAL_STEPS}" \
134
+ --warmup_steps "${WARMUP_STEPS}" \
135
+ --log_every "${LOG_EVERY}" \
136
+ --eval_every "${EVAL_EVERY}" \
137
+ --save_every "${SAVE_EVERY}" \
138
+ --latest_every "${LATEST_EVERY}" \
139
+ --optimizer adamw \
140
+ --lr "${LR}" \
141
+ --lr_schedule cosine \
142
+ --min_lr "${MIN_LR}" \
143
+ --weight_decay "${WEIGHT_DECAY}" \
144
+ --output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
145
+ --adamw_param_groups nanogpt \
146
+ --adam_beta1 "${ADAM_BETA1}" \
147
+ --adam_beta2 "${ADAM_BETA2}" \
148
+ --adam_eps "${ADAM_EPS}" \
149
+ --ema_decay "${EMA_DECAY}" \
150
+ --ema_start_step "${EMA_START_STEP}" \
151
+ --grad_clip "${GRAD_CLIP}" \
152
+ --seed 123 \
153
+ --d_model 768 \
154
+ --cond_dim 128 \
155
+ --n_layers 12 \
156
+ --n_heads 12 \
157
+ --dim_ff 3072 \
158
+ --dropout 0.0 \
159
+ --no-output_bias \
160
+ --norm_type rmsnorm \
161
+ --model_type ddit \
162
+ --state_format prob \
163
+ --bridge dirichlet \
164
+ --target_loss hard_ce \
165
+ --loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
166
+ --loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
167
+ --loss_t_drop_below "${LOSS_T_DROP_BELOW}" \
168
+ --target_prob 1.0 \
169
+ --min_t 0.0 \
170
+ --max_t 1.0 \
171
+ --t_sampling_mode "${T_SAMPLING_MODE}" \
172
+ --t_sampling_power "${T_SAMPLING_POWER}" \
173
+ --t_sampling_logit_mean "${T_SAMPLING_LOGIT_MEAN}" \
174
+ --t_sampling_logit_std "${T_SAMPLING_LOGIT_STD}" \
175
+ --t_sampling_eps "${T_SAMPLING_EPS}" \
176
+ --dual_t \
177
+ --corrupt_t_mode same \
178
+ --corrupt_min_t 0.0 \
179
+ --corrupt_max_t 1.0 \
180
+ --min_mask_ratio "${MIN_MASK_RATIO}" \
181
+ --max_mask_ratio "${MAX_MASK_RATIO}" \
182
+ --wrong_token_replace_prob 1.0 \
183
+ --wrong_token_schedule linear_t \
184
+ --wrong_token_exp_k 1.0 \
185
+ --dirichlet_concentration_min 1.0 \
186
+ --dirichlet_concentration_max 1024 \
187
+ --dirichlet_endpoint_mode categorical_dual_t \
188
+ --dirichlet_semantic_t_mode same \
189
+ --dirichlet_semantic_t_value 0.0 \
190
+ --categorical_wrong_from_full_vocab \
191
+ --simplex_bridge_sampler dirichlet \
192
+ --eps 1e-8 \
193
+ --infer_steps 1024 \
194
+ --decode_damping 1.0 \
195
+ --max_gamma 1.0 \
196
+ --decode_solver flowmap \
197
+ --noise_init logistic_normal \
198
+ --bridge_noise_init logistic_normal \
199
+ --noise_sigma -1 \
200
+ "${TF32_FLAG}" \
201
+ --ddp_gradient_as_bucket_view \
202
+ 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_uniform_mask1_swiglu_wd0p1_fp32_8gpu.sh ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
8
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
13
+
14
+ # Fully-coupled OWT baseline with FP32 params/activations and nanoGPT-style AdamW decay:
15
+ # decay: matrix / embedding params (p.dim() >= 2)
16
+ # no decay: bias / norm / 1D params
17
+ # This intentionally does NOT pass --bf16. TF32 is enabled by default so H200 uses
18
+ # Tensor Cores; set ALLOW_TF32=0 for strict FP32 debugging.
19
+
20
+ T_SAMPLING_MODE="${T_SAMPLING_MODE:-uniform}"
21
+ T_SAMPLING_POWER="${T_SAMPLING_POWER:-1.0}"
22
+ T_SAMPLING_EPS="${T_SAMPLING_EPS:-1e-4}"
23
+ T_SAMPLING_LOGIT_MEAN="${T_SAMPLING_LOGIT_MEAN:--0.22}"
24
+ T_SAMPLING_LOGIT_STD="${T_SAMPLING_LOGIT_STD:-0.5}"
25
+ MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
26
+ MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
27
+ DDIT_MLP_TYPE="${DDIT_MLP_TYPE:-swiglu}"
28
+
29
+ sanitize_label() {  # Print $1 with every '-' replaced by 'm' and every '.' by 'p' (e.g. -0.22 -> m0p22); the results are used as *_LABEL values below.
30
+ printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
31
+ }
32
+
33
+ T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN}")"
34
+ T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD}")"
35
+ MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO}")"
36
+ MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO}")"
37
+
38
+ RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_${DDIT_MLP_TYPE}_adamw_wd0p1_uniformt_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_nanogpt_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
39
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
40
+ LOG_DIR="${LOG_DIR:-logs/fullycoupled_uniform_mask1_swiglu_wd0p1_fp32_8gpu}"
41
+ LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
42
+
43
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
44
+ OWT_CACHE="${OWT_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
45
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
46
+
47
+ NNODES="${NNODES:-1}"
48
+ NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
49
+ NODE_RANK="${NODE_RANK:-0}"
50
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
51
+ MASTER_PORT="${MASTER_PORT:-31997}"
52
+
53
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
54
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
55
+ TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
56
+ WARMUP_STEPS="${WARMUP_STEPS:-2000}"
57
+ NUM_WORKERS="${NUM_WORKERS:-8}"
58
+ DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
59
+ LOG_EVERY="${LOG_EVERY:-50}"
60
+ SAVE_EVERY="${SAVE_EVERY:-50000}"
61
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
62
+ EVAL_EVERY="${EVAL_EVERY:-0}"
63
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
64
+ ALLOW_TF32="${ALLOW_TF32:-1}"
65
+ DRY_RUN="${DRY_RUN:-0}"
66
+
67
+ LR="${LR:-6e-4}"
68
+ MIN_LR="${MIN_LR:-6e-5}"
69
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
70
+ OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
71
+ ADAM_BETA1="${ADAM_BETA1:-0.9}"
72
+ ADAM_BETA2="${ADAM_BETA2:-0.95}"
73
+ ADAM_EPS="${ADAM_EPS:-1e-8}"
74
+ GRAD_CLIP="${GRAD_CLIP:-1.0}"
75
+ EMA_DECAY="${EMA_DECAY:-0.0}"
76
+ EMA_START_STEP="${EMA_START_STEP:-0}"
77
+ LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
78
+ LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
79
+ LOSS_T_DROP_BELOW="${LOSS_T_DROP_BELOW:-0.2}"
80
+
81
+ if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then  # Abort with exit status 2 when SAVE_DIR already holds an args.json, unless ALLOW_EXISTING_SAVE_DIR=1.
82
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
83
+ echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
84
+ exit 2
85
+ fi
86
+
87
+ mkdir -p "${LOG_DIR}" "$(dirname "${LOG_FILE}")" "${SAVE_DIR}"  # LOG_FILE can be overridden to a path outside LOG_DIR; create its parent too so the final `tee -a` cannot fail on a missing directory.
88
+
89
+ TF32_FLAG="--allow_tf32"  # Default: pass --allow_tf32 to train.py; switched to --no-allow_tf32 below when ALLOW_TF32 is 0/false/False.
90
+ TF32_LABEL="true"  # Value echoed as tf32=... in the [launch] banner.
91
+ if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
92
+ TF32_FLAG="--no-allow_tf32"
93
+ TF32_LABEL="false"
94
+ fi
95
+
96
+ echo "[launch] method=owt_fullycoupled_adamw_wd0p1_nanogpt_fp32 host=$(hostname) time=$(date -Iseconds)"
97
+ echo "[launch] run_name=${RUN_NAME}"
98
+ echo "[launch] save_dir=${SAVE_DIR}"
99
+ echo "[launch] log_file=${LOG_FILE}"
100
+ echo "[launch] data_path=${DATA_PATH}"
101
+ echo "[launch] owt_cache=${OWT_CACHE}"
102
+ echo "[launch] optimizer=adamw lr=${LR} min_lr=${MIN_LR} wd=${WEIGHT_DECAY} output_wd=${OUTPUT_WEIGHT_DECAY} param_groups=nanogpt ema=${EMA_DECAY}"
103
+ echo "[launch] fp32=true bf16=false tf32=${TF32_LABEL} norm_type=rmsnorm output_bias=false ddit_mlp_type=${DDIT_MLP_TYPE} batch=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE}"
104
+ echo "[launch] loss_t_weight_mode=${LOSS_T_WEIGHT_MODE} loss_t_min_weight=${LOSS_T_MIN_WEIGHT} loss_t_drop_below=${LOSS_T_DROP_BELOW}"
105
+ echo "[launch] target_loss=hard_ce t_sampling_mode=${T_SAMPLING_MODE} t_sampling_logit_mean=${T_SAMPLING_LOGIT_MEAN} t_sampling_logit_std=${T_SAMPLING_LOGIT_STD} t_sampling_power=${T_SAMPLING_POWER} t_sampling_eps=${T_SAMPLING_EPS} mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
106
+
107
+ if [[ "${DRY_RUN}" == "1" || "${DRY_RUN}" == "true" || "${DRY_RUN}" == "True" ]]; then
108
+ echo "[launch] DRY_RUN=1, validated launch parameters; skipping torchrun."
109
+ exit 0
110
+ fi
111
+
112
+ python -m torch.distributed.run \
113
+ --nnodes="${NNODES}" \
114
+ --nproc_per_node="${NPROC_PER_NODE}" \
115
+ --node_rank="${NODE_RANK}" \
116
+ --master_addr="${MASTER_ADDR}" \
117
+ --master_port="${MASTER_PORT}" \
118
+ train.py \
119
+ --data_path "${DATA_PATH}" \
120
+ --openwebtext_split train_minus_100k \
121
+ --text_column text \
122
+ --detokenizer auto \
123
+ --tokenizer_path "${TOKENIZER_PATH}" \
124
+ --save_dir "${SAVE_DIR}" \
125
+ --wrap \
126
+ --wrap_mode stream \
127
+ --owt_cached_chunks \
128
+ --owt_chunk_cache_dir "${OWT_CACHE}" \
129
+ --max_len 1024 \
130
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
131
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
132
+ --num_workers "${NUM_WORKERS}" \
133
+ --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
134
+ --total_steps "${TOTAL_STEPS}" \
135
+ --warmup_steps "${WARMUP_STEPS}" \
136
+ --log_every "${LOG_EVERY}" \
137
+ --eval_every "${EVAL_EVERY}" \
138
+ --save_every "${SAVE_EVERY}" \
139
+ --latest_every "${LATEST_EVERY}" \
140
+ --optimizer adamw \
141
+ --lr "${LR}" \
142
+ --lr_schedule cosine \
143
+ --min_lr "${MIN_LR}" \
144
+ --weight_decay "${WEIGHT_DECAY}" \
145
+ --output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
146
+ --adamw_param_groups nanogpt \
147
+ --adam_beta1 "${ADAM_BETA1}" \
148
+ --adam_beta2 "${ADAM_BETA2}" \
149
+ --adam_eps "${ADAM_EPS}" \
150
+ --ema_decay "${EMA_DECAY}" \
151
+ --ema_start_step "${EMA_START_STEP}" \
152
+ --grad_clip "${GRAD_CLIP}" \
153
+ --seed 123 \
154
+ --d_model 768 \
155
+ --cond_dim 128 \
156
+ --n_layers 12 \
157
+ --n_heads 12 \
158
+ --dim_ff 3072 \
159
+ --dropout 0.0 \
160
+ --no-output_bias \
161
+ --norm_type rmsnorm \
162
+ --model_type ddit \
163
+ --ddit_mlp_type "${DDIT_MLP_TYPE}" \
164
+ --state_format prob \
165
+ --bridge dirichlet \
166
+ --target_loss hard_ce \
167
+ --loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
168
+ --loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
169
+ --loss_t_drop_below "${LOSS_T_DROP_BELOW}" \
170
+ --target_prob 1.0 \
171
+ --min_t 0.0 \
172
+ --max_t 1.0 \
173
+ --t_sampling_mode "${T_SAMPLING_MODE}" \
174
+ --t_sampling_power "${T_SAMPLING_POWER}" \
175
+ --t_sampling_logit_mean "${T_SAMPLING_LOGIT_MEAN}" \
176
+ --t_sampling_logit_std "${T_SAMPLING_LOGIT_STD}" \
177
+ --t_sampling_eps "${T_SAMPLING_EPS}" \
178
+ --dual_t \
179
+ --corrupt_t_mode same \
180
+ --corrupt_min_t 0.0 \
181
+ --corrupt_max_t 1.0 \
182
+ --min_mask_ratio "${MIN_MASK_RATIO}" \
183
+ --max_mask_ratio "${MAX_MASK_RATIO}" \
184
+ --wrong_token_replace_prob 1.0 \
185
+ --wrong_token_schedule linear_t \
186
+ --wrong_token_exp_k 1.0 \
187
+ --dirichlet_concentration_min 1.0 \
188
+ --dirichlet_concentration_max 1024 \
189
+ --dirichlet_endpoint_mode categorical_dual_t \
190
+ --dirichlet_semantic_t_mode same \
191
+ --dirichlet_semantic_t_value 0.0 \
192
+ --categorical_wrong_from_full_vocab \
193
+ --simplex_bridge_sampler dirichlet \
194
+ --eps 1e-8 \
195
+ --infer_steps 1024 \
196
+ --decode_damping 1.0 \
197
+ --max_gamma 1.0 \
198
+ --decode_solver flowmap \
199
+ --noise_init logistic_normal \
200
+ --bridge_noise_init logistic_normal \
201
+ --noise_sigma -1 \
202
+ "${TF32_FLAG}" \
203
+ --ddp_gradient_as_bucket_view \
204
+ 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
7
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
8
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
13
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
14
+
15
+ STAMP="$(date +%Y%m%d_%H%M%S)"
16
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
17
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
18
+ OWT_CHUNK_CACHE_DIR="${OWT_CHUNK_CACHE_DIR:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
19
+ OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-0}"
20
+
21
+ NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
22
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
23
+ MASTER_PORT="${MASTER_PORT:-32251}"
24
+
25
+ MAX_LEN="${MAX_LEN:-1024}"
26
+ LOWK_CLEAN_TOKENS="${LOWK_CLEAN_TOKENS:-64,128,256}"
27
+ CLEAN_STATE_MODE="${CLEAN_STATE_MODE:-onehot}"
28
+ MASK_MIXTURE_LOWK_PROB="${MASK_MIXTURE_LOWK_PROB:-1.0}"
29
+ MASK_MIXTURE_ALL_PROB="${MASK_MIXTURE_ALL_PROB:-0.0}"
30
+ VOCAB_SIZE_OVERRIDE="${VOCAB_SIZE_OVERRIDE:-0}"
31
+ D_MODEL="${D_MODEL:-768}"
32
+ COND_DIM="${COND_DIM:-128}"
33
+ N_LAYERS="${N_LAYERS:-12}"
34
+ N_HEADS="${N_HEADS:-12}"
35
+ DIM_FF="${DIM_FF:-3072}"
36
+ TOTAL_STEPS="${TOTAL_STEPS:-500}"
37
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
38
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
39
+ NUM_WORKERS="${NUM_WORKERS:-4}"
40
+ DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
41
+ MUON_IMPL="${MUON_IMPL:-legacy}"
42
+ OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
43
+ MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
44
+ MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
45
+ MASK_RATIO_FLOOR_SCHEDULE="${MASK_RATIO_FLOOR_SCHEDULE:-none}"
46
+ DIRICHLET_CONCENTRATION_MIN="${DIRICHLET_CONCENTRATION_MIN:-1.0}"
47
+ DIRICHLET_CONCENTRATION_MAX="${DIRICHLET_CONCENTRATION_MAX:-1024}"
48
+ DIRICHLET_ENDPOINT_MODE="${DIRICHLET_ENDPOINT_MODE:-categorical_dual_t}"
49
+ DIRICHLET_SEMANTIC_T_MODE="${DIRICHLET_SEMANTIC_T_MODE:-same}"
50
+ DIRICHLET_SEMANTIC_T_VALUE="${DIRICHLET_SEMANTIC_T_VALUE:-0.0}"
51
+ SIMPLEX_BRIDGE_SAMPLER="${SIMPLEX_BRIDGE_SAMPLER:-dirichlet}"
52
+ ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA="${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA:-0.0}"
53
+ CATEGORICAL_WRONG_BASIN_TOKEN_IDS="${CATEGORICAL_WRONG_BASIN_TOKEN_IDS:-}"
54
+ CATEGORICAL_WRONG_BASIN_PROB="${CATEGORICAL_WRONG_BASIN_PROB:-0.0}"
55
+ CATEGORICAL_WRONG_UNIGRAM_PROB="${CATEGORICAL_WRONG_UNIGRAM_PROB:-0.0}"
56
+ CATEGORICAL_WRONG_UNIFORM_PROB="${CATEGORICAL_WRONG_UNIFORM_PROB:-0.0}"
57
+ CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
58
+ CATEGORICAL_WRONG_BASIN_SHARED_PROB="${CATEGORICAL_WRONG_BASIN_SHARED_PROB:-0.0}"
59
+ CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB="${CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB:-0.0}"
60
+ LOGISTIC_NORMAL_SIGMA_MIN="${LOGISTIC_NORMAL_SIGMA_MIN:-0.1}"
61
+ LOGISTIC_NORMAL_SIGMA_MAX="${LOGISTIC_NORMAL_SIGMA_MAX:-1.0}"
62
+ LOGISTIC_NORMAL_TAU_MIN="${LOGISTIC_NORMAL_TAU_MIN:-1.0}"
63
+ LOGISTIC_NORMAL_TAU_MAX="${LOGISTIC_NORMAL_TAU_MAX:-1.0}"
64
+ ROLLOUT_TRAIN_PROB="${ROLLOUT_TRAIN_PROB:-0.0}"
65
+ ROLLOUT_TRAIN_STEPS="${ROLLOUT_TRAIN_STEPS:-1}"
66
+ ROLLOUT_TRAIN_STEPS_MIN="${ROLLOUT_TRAIN_STEPS_MIN:--1}"
67
+ ROLLOUT_TRAIN_INFER_STEPS="${ROLLOUT_TRAIN_INFER_STEPS:-64}"
68
+ ROLLOUT_TRAIN_TIME_MODE="${ROLLOUT_TRAIN_TIME_MODE:-fixed_steps}"
69
+ ROLLOUT_TRAIN_S_DIST="${ROLLOUT_TRAIN_S_DIST:-uniform}"
70
+ ROLLOUT_TRAIN_S_MIN_FRAC="${ROLLOUT_TRAIN_S_MIN_FRAC:-0.0}"
71
+ ROLLOUT_TRAIN_S_MAX_FRAC="${ROLLOUT_TRAIN_S_MAX_FRAC:-0.125}"
72
+ ROLLOUT_TRAIN_S_BETA_ALPHA="${ROLLOUT_TRAIN_S_BETA_ALPHA:-2.0}"
73
+ ROLLOUT_TRAIN_S_BETA_BETA="${ROLLOUT_TRAIN_S_BETA_BETA:-6.0}"
74
+ ROLLOUT_TRAIN_TEMP="${ROLLOUT_TRAIN_TEMP:-1.0}"
75
+ ROLLOUT_TRAIN_MAX_GAMMA="${ROLLOUT_TRAIN_MAX_GAMMA:-1.0}"
76
+ ROLLOUT_TRAIN_CORRUPT_ONLY="${ROLLOUT_TRAIN_CORRUPT_ONLY:-1}"
77
+ ROLLOUT_TRAIN_SAMPLEWISE="${ROLLOUT_TRAIN_SAMPLEWISE:-0}"
78
+ ROLLOUT_TRAIN_SELECTED_ONLY="${ROLLOUT_TRAIN_SELECTED_ONLY:-0}"
79
+ ROLLOUT_TRAIN_COMPUTE_ALWAYS="${ROLLOUT_TRAIN_COMPUTE_ALWAYS:-0}"
80
+ ROLLOUT_TRAIN_SYNC_T="${ROLLOUT_TRAIN_SYNC_T:-0}"
81
+
82
+ TARGET_LOSS="${TARGET_LOSS:-linear_soft_kl}"
83
+ LINEAR_SOFT_TARGET_POWER="${LINEAR_SOFT_TARGET_POWER:-1.0}"
84
+ LINEAR_SOFT_TARGET_MIN_CONF="${LINEAR_SOFT_TARGET_MIN_CONF:-0.0}"
85
+ LINEAR_SOFT_TARGET_MAX_CONF="${LINEAR_SOFT_TARGET_MAX_CONF:-1.0}"
86
+
87
+ tag() {  # Print $1 with ',' -> '_', '-' -> 'm' and '.' -> 'p' (e.g. 64,128,256 -> 64_128_256); used to build the default RUN_NAME.
88
+ printf "%s" "$1" | sed -e 's/,/_/g' -e 's/-/m/g' -e 's/\./p/g'
89
+ }
90
+
91
+ RUN_NAME="${RUN_NAME:-pilot_gpt2_softendpoint_mn_n${MAX_LEN}_m$(tag "${LOWK_CLEAN_TOKENS}")_${CLEAN_STATE_MODE}_${TOTAL_STEPS}_${STAMP}}"
92
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
93
+ LOG_FILE="${LOG_FILE:-logs/softendpoint_mn_pilot_4gpu/${RUN_NAME}.log}"
94
+ RESUME_PATH="${RESUME_PATH:-}"
95
+ INIT_MODEL_PATH="${INIT_MODEL_PATH:-}"
96
+ mkdir -p "$(dirname "${LOG_FILE}")" "${SAVE_DIR}"
97
+
98
+ echo "[launch] gpt2 cached OWT soft-endpoint m/n pilot"
99
+ echo "[launch] run_name=${RUN_NAME}"
100
+ echo "[launch] save_dir=${SAVE_DIR}"
101
+ echo "[launch] n=${MAX_LEN} m=${LOWK_CLEAN_TOKENS} clean_state_mode=${CLEAN_STATE_MODE}"
102
+ echo "[launch] mask_mixture lowk=${MASK_MIXTURE_LOWK_PROB} all=${MASK_MIXTURE_ALL_PROB}"
103
+ echo "[launch] model d=${D_MODEL} layers=${N_LAYERS} heads=${N_HEADS} ff=${DIM_FF} vocab_override=${VOCAB_SIZE_OVERRIDE}"
104
+ echo "[launch] optimizer=muon muon_impl=${MUON_IMPL} weight_decay=${WEIGHT_DECAY:-0.1} output_weight_decay=${OUTPUT_WEIGHT_DECAY}"
105
+ echo "[launch] target_loss=${TARGET_LOSS} conf=${LINEAR_SOFT_TARGET_MIN_CONF}->${LINEAR_SOFT_TARGET_MAX_CONF} power=${LINEAR_SOFT_TARGET_POWER}"
106
+ echo "[launch] mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
107
+ echo "[launch] mask_ratio_floor_schedule=${MASK_RATIO_FLOOR_SCHEDULE}"
108
+ echo "[launch] dirichlet C=${DIRICHLET_CONCENTRATION_MIN}->${DIRICHLET_CONCENTRATION_MAX} endpoint=${DIRICHLET_ENDPOINT_MODE} sampler=${SIMPLEX_BRIDGE_SAMPLER}"
109
+ echo "[launch] wrong_mix seq_alpha=${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA} wrong_floor=${CATEGORICAL_WRONG_PROB_FLOOR} unigram=${CATEGORICAL_WRONG_UNIGRAM_PROB} uniform=${CATEGORICAL_WRONG_UNIFORM_PROB} basin=${CATEGORICAL_WRONG_BASIN_PROB} basin_ids=${CATEGORICAL_WRONG_BASIN_TOKEN_IDS}"
110
+ echo "[launch] rollout_train prob=${ROLLOUT_TRAIN_PROB} mode=${ROLLOUT_TRAIN_TIME_MODE} steps=${ROLLOUT_TRAIN_STEPS} steps_min=${ROLLOUT_TRAIN_STEPS_MIN} infer_steps=${ROLLOUT_TRAIN_INFER_STEPS} s_dist=${ROLLOUT_TRAIN_S_DIST} s_frac=${ROLLOUT_TRAIN_S_MIN_FRAC}->${ROLLOUT_TRAIN_S_MAX_FRAC} temp=${ROLLOUT_TRAIN_TEMP} corrupt_only=${ROLLOUT_TRAIN_CORRUPT_ONLY} samplewise=${ROLLOUT_TRAIN_SAMPLEWISE} selected_only=${ROLLOUT_TRAIN_SELECTED_ONLY} sync_t=${ROLLOUT_TRAIN_SYNC_T}"
111
+ echo "[launch] cache=${OWT_CHUNK_CACHE_DIR} exact_repeat_per_chunk=${OWT_EXACT_REPEAT_PER_CHUNK}"
112
+ if [[ -n "${RESUME_PATH}" ]]; then
113
+ echo "[launch] resume_path=${RESUME_PATH}"
114
+ fi
115
+ if [[ -n "${INIT_MODEL_PATH}" ]]; then
116
+ echo "[launch] init_model_path=${INIT_MODEL_PATH}"
117
+ fi
118
+
119
+ resume_args=()
120
+ if [[ -n "${RESUME_PATH}" ]]; then
121
+ resume_args=(--resume_path "${RESUME_PATH}")
122
+ fi
123
+ init_args=()
124
+ if [[ -n "${INIT_MODEL_PATH}" ]]; then
125
+ init_args=(--init_model_path "${INIT_MODEL_PATH}")
126
+ fi
127
+ rollout_bool_args=()  # Boolean rollout flags for train.py, derived from the ROLLOUT_TRAIN_* env toggles and expanded into the launch command.
128
+ if [[ "${ROLLOUT_TRAIN_CORRUPT_ONLY}" == "0" || "${ROLLOUT_TRAIN_CORRUPT_ONLY}" == "false" || "${ROLLOUT_TRAIN_CORRUPT_ONLY}" == "False" ]]; then  # corrupt_only is always passed explicitly: any value other than 0/false/False selects the positive flag.
129
+ rollout_bool_args+=(--no-rollout_train_corrupt_only)
130
+ else
131
+ rollout_bool_args+=(--rollout_train_corrupt_only)
132
+ fi
133
+ if [[ "${ROLLOUT_TRAIN_SAMPLEWISE}" == "1" || "${ROLLOUT_TRAIN_SAMPLEWISE}" == "true" || "${ROLLOUT_TRAIN_SAMPLEWISE}" == "True" ]]; then  # The remaining toggles only add their flag when set to 1/true/True; otherwise nothing is passed.
134
+ rollout_bool_args+=(--rollout_train_samplewise)
135
+ fi
136
+ if [[ "${ROLLOUT_TRAIN_SELECTED_ONLY}" == "1" || "${ROLLOUT_TRAIN_SELECTED_ONLY}" == "true" || "${ROLLOUT_TRAIN_SELECTED_ONLY}" == "True" ]]; then
137
+ rollout_bool_args+=(--rollout_train_selected_only)
138
+ fi
139
+ if [[ "${ROLLOUT_TRAIN_COMPUTE_ALWAYS}" == "1" || "${ROLLOUT_TRAIN_COMPUTE_ALWAYS}" == "true" || "${ROLLOUT_TRAIN_COMPUTE_ALWAYS}" == "True" ]]; then
140
+ rollout_bool_args+=(--rollout_train_compute_always)
141
+ fi
142
+ if [[ "${ROLLOUT_TRAIN_SYNC_T}" == "1" || "${ROLLOUT_TRAIN_SYNC_T}" == "true" || "${ROLLOUT_TRAIN_SYNC_T}" == "True" ]]; then
143
+ rollout_bool_args+=(--rollout_train_sync_t)
144
+ fi
145
+
146
+ python -m torch.distributed.run \
147
+ --nnodes=1 \
148
+ --nproc_per_node="${NPROC_PER_NODE}" \
149
+ --node_rank=0 \
150
+ --master_addr="${MASTER_ADDR}" \
151
+ --master_port="${MASTER_PORT}" \
152
+ train.py \
153
+ --data_path "${DATA_PATH}" \
154
+ --text_column text \
155
+ --openwebtext_split train_minus_100k \
156
+ --detokenizer auto \
157
+ --tokenizer_path "${TOKENIZER_PATH}" \
158
+ --save_dir "${SAVE_DIR}" \
159
+ ${resume_args[@]+"${resume_args[@]}"} \
160
+ ${init_args[@]+"${init_args[@]}"} \
161
+ --wrap \
162
+ --wrap_mode stream \
163
+ --wrap_record_buffer_size 200 \
164
+ --owt_cached_chunks \
165
+ --owt_chunk_cache_dir "${OWT_CHUNK_CACHE_DIR}" \
166
+ --owt_chunk_cache_write_batch 4096 \
167
+ --owt_exact_repeat_per_chunk "${OWT_EXACT_REPEAT_PER_CHUNK}" \
168
+ --max_len "${MAX_LEN}" \
169
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
170
+ --num_workers "${NUM_WORKERS}" \
171
+ --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
172
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
173
+ --total_steps "${TOTAL_STEPS}" \
174
+ --log_every "${LOG_EVERY:-50}" \
175
+ --eval_every 0 \
176
+ --save_every "${SAVE_EVERY:-500}" \
177
+ --latest_every "${LATEST_EVERY:-250}" \
178
+ --lr "${LEARNING_RATE:-0.002}" \
179
+ --weight_decay "${WEIGHT_DECAY:-0.1}" \
180
+ --output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
181
+ --adam_beta1 0.9 \
182
+ --adam_beta2 0.95 \
183
+ --adam_eps 1e-8 \
184
+ --optimizer muon \
185
+ --muon_impl "${MUON_IMPL}" \
186
+ --muon_momentum 0.95 \
187
+ --muon_ns_steps 5 \
188
+ --muon_update_scale 1.0 \
189
+ --ema_decay 0.9999 \
190
+ --ema_start_step 0 \
191
+ --warmup_steps "${WARMUP_STEPS:-100}" \
192
+ --lr_schedule constant_warmup \
193
+ --min_lr 0 \
194
+ --grad_clip 1.0 \
195
+ --adamw_param_groups nanogpt \
196
+ --seed 123 \
197
+ --d_model "${D_MODEL}" \
198
+ --cond_dim "${COND_DIM}" \
199
+ --n_layers "${N_LAYERS}" \
200
+ --n_heads "${N_HEADS}" \
201
+ --dim_ff "${DIM_FF}" \
202
+ --vocab_size_override "${VOCAB_SIZE_OVERRIDE}" \
203
+ --dropout 0.0 \
204
+ --model_type ddit \
205
+ --state_format prob \
206
+ --bridge dirichlet \
207
+ --target_loss "${TARGET_LOSS}" \
208
+ --linear_soft_target_power "${LINEAR_SOFT_TARGET_POWER}" \
209
+ --linear_soft_target_min_conf "${LINEAR_SOFT_TARGET_MIN_CONF}" \
210
+ --linear_soft_target_max_conf "${LINEAR_SOFT_TARGET_MAX_CONF}" \
211
+ --rollout_train_prob "${ROLLOUT_TRAIN_PROB}" \
212
+ --rollout_train_steps "${ROLLOUT_TRAIN_STEPS}" \
213
+ --rollout_train_steps_min "${ROLLOUT_TRAIN_STEPS_MIN}" \
214
+ --rollout_train_infer_steps "${ROLLOUT_TRAIN_INFER_STEPS}" \
215
+ --rollout_train_time_mode "${ROLLOUT_TRAIN_TIME_MODE}" \
216
+ --rollout_train_s_dist "${ROLLOUT_TRAIN_S_DIST}" \
217
+ --rollout_train_s_min_frac "${ROLLOUT_TRAIN_S_MIN_FRAC}" \
218
+ --rollout_train_s_max_frac "${ROLLOUT_TRAIN_S_MAX_FRAC}" \
219
+ --rollout_train_s_beta_alpha "${ROLLOUT_TRAIN_S_BETA_ALPHA}" \
220
+ --rollout_train_s_beta_beta "${ROLLOUT_TRAIN_S_BETA_BETA}" \
221
+ --rollout_train_temp "${ROLLOUT_TRAIN_TEMP}" \
222
+ --rollout_train_max_gamma "${ROLLOUT_TRAIN_MAX_GAMMA}" \
223
+ "${rollout_bool_args[@]}" \
224
+ --target_prob 1.0 \
225
+ --min_t 0.0 \
226
+ --max_t 1.0 \
227
+ --t_sampling_mode "${T_SAMPLING_MODE:-logit_normal}" \
228
+ --t_sampling_logit_mean "${T_LOGIT_MEAN:--1.5}" \
229
+ --t_sampling_logit_std "${T_LOGIT_STD:-0.8}" \
230
+ --t_sampling_eps 1e-4 \
231
+ --dual_t \
232
+ --corrupt_t_mode same \
233
+ --corrupt_min_t 0.0 \
234
+ --corrupt_max_t 1.0 \
235
+ --min_mask_ratio "${MIN_MASK_RATIO}" \
236
+ --max_mask_ratio "${MAX_MASK_RATIO}" \
237
+ --mask_ratio_floor_schedule "${MASK_RATIO_FLOOR_SCHEDULE}" \
238
+ --mask_mixture_lowk_prob "${MASK_MIXTURE_LOWK_PROB}" \
239
+ --mask_mixture_all_prob "${MASK_MIXTURE_ALL_PROB}" \
240
+ --mask_mixture_lowk_clean_tokens "${LOWK_CLEAN_TOKENS}" \
241
+ --clean_state_mode "${CLEAN_STATE_MODE}" \
242
+ --wrong_token_replace_prob 1.0 \
243
+ --wrong_token_schedule linear_t \
244
+ --wrong_token_exp_k 1.0 \
245
+ --dirichlet_concentration_min "${DIRICHLET_CONCENTRATION_MIN}" \
246
+ --dirichlet_concentration_max "${DIRICHLET_CONCENTRATION_MAX}" \
247
+ --dirichlet_endpoint_mode "${DIRICHLET_ENDPOINT_MODE}" \
248
+ --dirichlet_semantic_t_mode "${DIRICHLET_SEMANTIC_T_MODE}" \
249
+ --dirichlet_semantic_t_value "${DIRICHLET_SEMANTIC_T_VALUE}" \
250
+ --endpoint_sequence_random_prob_alpha "${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA}" \
251
+ --categorical_wrong_basin_token_ids "${CATEGORICAL_WRONG_BASIN_TOKEN_IDS}" \
252
+ --categorical_wrong_basin_prob "${CATEGORICAL_WRONG_BASIN_PROB}" \
253
+ --categorical_wrong_unigram_prob "${CATEGORICAL_WRONG_UNIGRAM_PROB}" \
254
+ --categorical_wrong_uniform_prob "${CATEGORICAL_WRONG_UNIFORM_PROB}" \
255
+ --categorical_wrong_prob_floor "${CATEGORICAL_WRONG_PROB_FLOOR}" \
256
+ --categorical_wrong_basin_shared_prob "${CATEGORICAL_WRONG_BASIN_SHARED_PROB}" \
257
+ --categorical_wrong_unigram_shared_prob "${CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB}" \
258
+ --categorical_wrong_from_full_vocab \
259
+ --simplex_bridge_sampler "${SIMPLEX_BRIDGE_SAMPLER}" \
260
+ --logistic_normal_sigma_min "${LOGISTIC_NORMAL_SIGMA_MIN}" \
261
+ --logistic_normal_sigma_max "${LOGISTIC_NORMAL_SIGMA_MAX}" \
262
+ --logistic_normal_tau_min "${LOGISTIC_NORMAL_TAU_MIN}" \
263
+ --logistic_normal_tau_max "${LOGISTIC_NORMAL_TAU_MAX}" \
264
+ --eps 1e-8 \
265
+ --noise_init logistic_normal \
266
+ --bridge_noise_init logistic_normal \
267
+ --noise_sigma -1 \
268
+ --allow_tf32 \
269
+ --ddp_gradient_as_bucket_view \
270
+ --bf16 \
271
+ 2>&1 | tee -a "${LOG_FILE}"
LTA_openwebtext_dualt/scripts/launch_lta_owt_selfcond_p05_4gpu.sh ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
7
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
8
+ export TOKENIZERS_PARALLELISM=false
9
+
10
+ STAMP="$(date +%Y%m%d_%H%M%S)"
11
+ export RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_selfcond_p05_rollout1_samplewise_c1024_ddit768x12_muon_ema_gbs512_4gpu_50k_${STAMP}}"
12
+ export SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
13
+ export LOG_FILE="${LOG_FILE:-logs/selfcond_4gpu/${RUN_NAME}.log}"
14
+ mkdir -p "$(dirname "${LOG_FILE}")"
15
+
16
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
17
+ export MASTER_PORT="${MASTER_PORT:-31994}"
18
+ export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
19
+ export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
20
+ export TOTAL_STEPS="${TOTAL_STEPS:-50000}"
21
+ export WARMUP_STEPS="${WARMUP_STEPS:-2000}"
22
+ export MAX_LEN="${MAX_LEN:-1024}"
23
+ export NUM_WORKERS="${NUM_WORKERS:-4}"
24
+ export DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
25
+ export LOG_EVERY="${LOG_EVERY:-20}"
26
+ export SAVE_EVERY="${SAVE_EVERY:-1000}"
27
+ export LATEST_EVERY="${LATEST_EVERY:-500}"
28
+ export EVAL_EVERY="${EVAL_EVERY:-0}"
29
+
30
+ export OWT_CACHED_CHUNKS=1
31
+ export OWT_CHUNK_CACHE_REBUILD=0
32
+ export OWT_CHUNK_CACHE_DIR="${OWT_CHUNK_CACHE_DIR:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
33
+
34
+ export OPTIMIZER="${OPTIMIZER:-muon}"
35
+ export LEARNING_RATE="${LEARNING_RATE:-0.002}"
36
+ export LR_SCHEDULE="${LR_SCHEDULE:-constant_warmup}"
37
+ export MIN_LR="${MIN_LR:-0}"
38
+ export WEIGHT_DECAY="${WEIGHT_DECAY:-0.0}"
39
+ export ADAM_BETA1="${ADAM_BETA1:-0.9}"
40
+ export ADAM_BETA2="${ADAM_BETA2:-0.95}"
41
+ export ADAM_EPS="${ADAM_EPS:-1e-8}"
42
+ export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}"
43
+ export MUON_NS_STEPS="${MUON_NS_STEPS:-5}"
44
+ export MUON_UPDATE_SCALE="${MUON_UPDATE_SCALE:-1.0}"
45
+ export EMA_DECAY="${EMA_DECAY:-0.9999}"
46
+ export EMA_START_STEP="${EMA_START_STEP:-0}"
47
+
48
+ export ALLOW_TF32="${ALLOW_TF32:-1}"
49
+ export ACTIVATION_CHECKPOINTING="${ACTIVATION_CHECKPOINTING:-0}"
50
+ export ACTIVATION_CHECKPOINT_INTERVAL="${ACTIVATION_CHECKPOINT_INTERVAL:-2}"
51
+ export DDP_GRADIENT_AS_BUCKET_VIEW="${DDP_GRADIENT_AS_BUCKET_VIEW:-1}"
52
+ export FULL_TRAIN_STATS="${FULL_TRAIN_STATS:-0}"
53
+ export ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
54
+ export ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
55
+
56
+ export ROLLOUT_TRAIN_PROB="${ROLLOUT_TRAIN_PROB:-0.5}"
57
+ export ROLLOUT_TRAIN_STEPS="${ROLLOUT_TRAIN_STEPS:-1}"
58
+ export ROLLOUT_TRAIN_INFER_STEPS="${ROLLOUT_TRAIN_INFER_STEPS:-64}"
59
+ export ROLLOUT_TRAIN_TEMP="${ROLLOUT_TRAIN_TEMP:-1.45}"
60
+ export ROLLOUT_TRAIN_MAX_GAMMA="${ROLLOUT_TRAIN_MAX_GAMMA:-1.0}"
61
+ export ROLLOUT_TRAIN_CORRUPT_ONLY="${ROLLOUT_TRAIN_CORRUPT_ONLY:-1}"
62
+ export ROLLOUT_TRAIN_SAMPLEWISE="${ROLLOUT_TRAIN_SAMPLEWISE:-1}"
63
+ export ROLLOUT_TRAIN_COMPUTE_ALWAYS="${ROLLOUT_TRAIN_COMPUTE_ALWAYS:-0}"
64
+
65
+ bash scripts/launch_lta_owt_c1024_fullycoupled_8gpu_len1024_gpt2_cached_chunks_1m.sh
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_blockar_twostream_8gpu.sh ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+ set -o pipefail
5
+
6
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
7
+
8
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
9
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
10
+ export MASTER_PORT="${MASTER_PORT:-32331}"
11
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
12
+ export TOKENIZERS_PARALLELISM=false
13
+ export PYTHONUNBUFFERED=1
14
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
15
+
16
+ DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
17
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
18
+
19
+ # Two-stream packs a 1024-token example into a 2048-token model sequence, so the
20
+ # dense first version needs a smaller per-GPU batch than the regular 1024 runs.
21
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-128}"
22
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-4}"
23
+ TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
24
+ WARMUP_STEPS="${WARMUP_STEPS:-2000}"
25
+ MAX_RECORDS="${MAX_RECORDS:-0}"
26
+
27
+ D_MODEL="${D_MODEL:-768}"
28
+ N_LAYERS="${N_LAYERS:-12}"
29
+ N_HEADS="${N_HEADS:-12}"
30
+ DIM_FF="${DIM_FF:-3072}"
31
+
32
+ RUN_NAME="${RUN_NAME:-lta_owt_t5_blockar_twostream_b8x128_mask1_uniformt_adaln_swiglu_gbs${GLOBAL_BATCH_SIZE}_${NPROC_PER_NODE}gpu_$(date +%Y%m%d_%H%M%S)}"
33
+
34
+ mkdir -p logs/blockar_twostream "runs/${RUN_NAME}"  # Quoted: RUN_NAME is caller-overridable and must not undergo word splitting or globbing (matches the quoted --save_dir below).
35
+
36
+ torchrun \
37
+ --nnodes=1 \
38
+ --nproc_per_node="${NPROC_PER_NODE}" \
39
+ --master_port="${MASTER_PORT}" \
40
+ train.py \
41
+ --data_path "${DATA_PATH}" \
42
+ --tokenized_hf \
43
+ --tokenized_pad_token pad \
44
+ --tokenizer_path "${TOKENIZER_PATH}" \
45
+ --max_records "${MAX_RECORDS}" \
46
+ --save_dir "runs/${RUN_NAME}" \
47
+ --max_len 1024 \
48
+ --batch_size "${PER_GPU_BATCH_SIZE}" \
49
+ --global_batch_size "${GLOBAL_BATCH_SIZE}" \
50
+ --num_workers 8 \
51
+ --dataloader_prefetch_factor 4 \
52
+ --epochs 0 \
53
+ --total_steps "${TOTAL_STEPS}" \
54
+ --warmup_steps "${WARMUP_STEPS}" \
55
+ --log_every 100 \
56
+ --eval_every 0 \
57
+ --save_every 1000 \
58
+ --latest_every 500 \
59
+ --optimizer adamw \
60
+ --lr 6e-4 \
61
+ --lr_schedule cosine \
62
+ --min_lr 6e-5 \
63
+ --weight_decay 0.1 \
64
+ --output_weight_decay -1 \
65
+ --adamw_param_groups nanogpt \
66
+ --adam_beta1 0.9 \
67
+ --adam_beta2 0.999 \
68
+ --adam_eps 1e-8 \
69
+ --ema_decay 0.9999 \
70
+ --ema_start_step 0 \
71
+ --grad_clip 1.0 \
72
+ --seed 42 \
73
+ --d_model "${D_MODEL}" \
74
+ --cond_dim 128 \
75
+ --n_layers "${N_LAYERS}" \
76
+ --n_heads "${N_HEADS}" \
77
+ --dim_ff "${DIM_FF}" \
78
+ --dropout 0.0 \
79
+ --no-output_bias \
80
+ --norm_type rmsnorm \
81
+ --model_type ddit \
82
+ --ddit_mlp_type swiglu \
83
+ --state_format prob \
84
+ --bridge dirichlet \
85
+ --target_loss hard_ce \
86
+ --loss_t_weight_mode none \
87
+ --loss_t_min_weight 0.0 \
88
+ --block_ar_two_stream \
89
+ --block_ar_block_len 128 \
90
+ --target_prob 1.0 \
91
+ --min_t 0.0 \
92
+ --max_t 1.0 \
93
+ --t_sampling_mode uniform \
94
+ --dual_t \
95
+ --corrupt_t_mode same \
96
+ --corrupt_min_t 0.0 \
97
+ --corrupt_max_t 1.0 \
98
+ --min_mask_ratio 1.0 \
99
+ --max_mask_ratio 1.0 \
100
+ --mask_mixture_all_prob 1.0 \
101
+ --wrong_token_replace_prob 1.0 \
102
+ --wrong_token_schedule linear_t \
103
+ --dirichlet_concentration_min 1.0 \
104
+ --dirichlet_concentration_max 1024 \
105
+ --dirichlet_endpoint_mode categorical_dual_t \
106
+ --categorical_wrong_from_full_vocab \
107
+ --simplex_bridge_sampler dirichlet \
108
+ --eps 1e-8 \
109
+ --infer_steps 1024 \
110
+ --decode_damping 1.0 \
111
+ --max_gamma 1.0 \
112
+ --decode_solver flowmap \
113
+ --noise_init logistic_normal \
114
+ --bridge_noise_init logistic_normal \
115
+ --noise_sigma -1 \
116
+ --allow_tf32 \
117
+ --activation_checkpointing \
118
+ --activation_checkpoint_scope mlp \
119
+ --ddp_gradient_as_bucket_view \
120
+ 2>&1 | tee -a "logs/blockar_twostream/${RUN_NAME}.log"
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_p50_randk0_4_8gpu.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
7
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
8
+ export MASTER_PORT="${MASTER_PORT:-32254}"
9
+
10
+ export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
11
+ export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
12
+
13
+ export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
14
+ export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
15
+ export EPOCHS="${EPOCHS:-5}"
16
+ export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
17
+ export NUM_WORKERS="${NUM_WORKERS:-8}"
18
+ export DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
19
+ export LOG_EVERY="${LOG_EVERY:-100}"
20
+ export LATEST_EVERY="${LATEST_EVERY:-1000}"
21
+ export ALLOW_TF32="${ALLOW_TF32:-1}"
22
+
23
+ export LR="${LR:-0.002}"
24
+ export WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
25
+ export OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
26
+ export OUTPUT_INIT_STD="${OUTPUT_INIT_STD:-0.0}"
27
+
28
+ # Corrected roll-in: model-temperature rollout, synced sampled path, random K in {0,1,2,3,4}.
29
+ export T_SAMPLING_MODE="${T_SAMPLING_MODE:-uniform}"
30
+ export ROLLOUT_TRAIN_PROB="${ROLLOUT_TRAIN_PROB:-0.50}"
31
+ export ROLLOUT_TRAIN_TIME_MODE="${ROLLOUT_TRAIN_TIME_MODE:-sampled_path}"
32
+ export ROLLOUT_TRAIN_STEPS="${ROLLOUT_TRAIN_STEPS:-4}"
33
+ export ROLLOUT_TRAIN_STEPS_MIN="${ROLLOUT_TRAIN_STEPS_MIN:-0}"
34
+ export ROLLOUT_TRAIN_INFER_STEPS="${ROLLOUT_TRAIN_INFER_STEPS:-1}"
35
+ export ROLLOUT_TRAIN_S_DIST="${ROLLOUT_TRAIN_S_DIST:-uniform}"
36
+ export ROLLOUT_TRAIN_S_MIN_FRAC="${ROLLOUT_TRAIN_S_MIN_FRAC:-0.0}"
37
+ export ROLLOUT_TRAIN_S_MAX_FRAC="${ROLLOUT_TRAIN_S_MAX_FRAC:-0.25}"
38
+ export ROLLOUT_TRAIN_TEMP="${ROLLOUT_TRAIN_TEMP:-1.0}"
39
+ export ROLLOUT_TRAIN_MAX_GAMMA="${ROLLOUT_TRAIN_MAX_GAMMA:-1.0}"
40
+ export ROLLOUT_TRAIN_CORRUPT_ONLY="${ROLLOUT_TRAIN_CORRUPT_ONLY:-1}"
41
+ export ROLLOUT_TRAIN_SAMPLEWISE="${ROLLOUT_TRAIN_SAMPLEWISE:-1}"
42
+ export ROLLOUT_TRAIN_SELECTED_ONLY="${ROLLOUT_TRAIN_SELECTED_ONLY:-1}"
43
+ export ROLLOUT_TRAIN_COMPUTE_ALWAYS="${ROLLOUT_TRAIN_COMPUTE_ALWAYS:-0}"
44
+ export ROLLOUT_TRAIN_SYNC_T="${ROLLOUT_TRAIN_SYNC_T:-1}"
45
+
46
+ export MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
47
+ export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
48
+ export MASK_MIXTURE_ORIGINAL_PROB="${MASK_MIXTURE_ORIGINAL_PROB:-0.0}"
49
+ export MASK_MIXTURE_LOWK_PROB="${MASK_MIXTURE_LOWK_PROB:-0.0}"
50
+ export MASK_MIXTURE_LOWCORRUPT_PROB="${MASK_MIXTURE_LOWCORRUPT_PROB:-0.0}"
51
+ export MASK_MIXTURE_BLOCK_PROB="${MASK_MIXTURE_BLOCK_PROB:-0.0}"
52
+ export MASK_MIXTURE_ALL_PROB="${MASK_MIXTURE_ALL_PROB:-1.0}"
53
+
54
+ export TARGET_LOSS="${TARGET_LOSS:-hard_ce}"
55
+ export T_LOGIT_MEAN="${T_LOGIT_MEAN:--1.5}"
56
+ export T_LOGIT_STD="${T_LOGIT_STD:-0.8}"
57
+
58
+ export RUN_NAME="${RUN_NAME:-lta_owt_t5_rollin_p50_randk0_4_uniformt_temp1_synct_mask1_gbs${GLOBAL_BATCH_SIZE}_${NPROC_PER_NODE}gpu_$(date +%Y%m%d_%H%M%S)}"
59
+
60
+ bash scripts/launch_lta_owt_elfaligned_t5_logitnormal_tokenized_8gpu.sh
LTA_openwebtext_dualt/scripts/make_compact_gpt2_bpe_tokenizer.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Iterable
8
+
9
+
10
+ SPECIAL_TOKENS = ["<pad>", "</s>", "<unk>", "<|endoftext|>"]
11
+
12
+
13
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the compact-tokenizer builder.

    Args:
        argv: Optional explicit argument list. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, which is exactly the
            behavior of the previous zero-argument signature.

    Returns:
        Namespace with ``source_tokenizer``, ``output_dir``, ``vocab_size``
        (int) and ``special_tokens`` (list of str).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Create compact byte-level GPT-2 BPE tokenizers by truncating the "
            "merge table while preserving byte fallback coverage."
        )
    )
    parser.add_argument("--source_tokenizer", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--vocab_size", type=int, required=True)
    parser.add_argument(
        "--special_tokens",
        nargs="*",
        # Hand out a copy so no caller can mutate the module-level default.
        default=list(SPECIAL_TOKENS),
        help="Special tokens to reserve first in the compact vocab.",
    )
    return parser.parse_args(argv)
30
+
31
+
32
def normalize_merge(merge: str | list[str]) -> tuple[str, str]:
    """Return a BPE merge rule as a ``(left, right)`` pair of strings.

    Two serialized forms are accepted: a single string ``"left right"``
    (split on the first space only) and a two-element list or tuple
    ``[left, right]``.

    Raises:
        ValueError: If ``merge`` is a string without a space separator, a
            sequence whose length is not 2, or any other type.
    """
    if isinstance(merge, str):
        left, sep, right = merge.partition(" ")
        if not sep:
            # Without this check the failure surfaced as an opaque
            # tuple-unpacking error that did not name the offending entry.
            raise ValueError(f"Unsupported merge entry: {merge!r}")
        return left, right
    if isinstance(merge, (list, tuple)) and len(merge) == 2:
        return str(merge[0]), str(merge[1])
    raise ValueError(f"Unsupported merge entry: {merge!r}")
39
+
40
+
41
def merge_to_json(left: str, right: str, like: str | list[str]) -> str | list[str]:
    """Serialize a merge pair in the same shape as the template entry ``like``.

    A string template yields ``"left right"``; any other template yields the
    two-element list ``[left, right]``.
    """
    return f"{left} {right}" if isinstance(like, str) else [left, right]
45
+
46
+
47
def base_byte_tokens(orig_id_to_token: dict[int, str]) -> Iterable[str]:
    """Lazily yield the tokens stored under ids 0..255 of the source vocab.

    The source tokenizer is expected to keep its 256 byte-level fallback
    tokens in the first 256 ids (the GPT-2 byte-level BPE layout).

    Raises:
        ValueError: When iteration reaches an id in ``range(256)`` that is
            absent from ``orig_id_to_token`` (or mapped to ``None``).
    """
    lookups = ((byte_id, orig_id_to_token.get(byte_id)) for byte_id in range(256))
    for byte_id, token in lookups:
        if token is None:
            raise ValueError(f"Source tokenizer is missing byte token id={byte_id}")
        yield token
54
+
55
+
56
def main() -> None:
    """Build a compact BPE tokenizer from a source ``tokenizer.json``.

    Steps:
      1. Load the source tokenizer JSON and read its vocab and merge table.
      2. Seed the compact vocab with the de-duplicated special tokens,
         followed by the 256 byte tokens stored under source ids 0..255.
      3. Walk the source merges in order, keeping a merge only when both
         operands are already in the compact vocab and the merged token
         exists in the source vocab; stop once ``--vocab_size`` is reached.
      4. Write ``tokenizer.json`` and ``tokenizer_meta.json`` into
         ``--output_dir`` and print the metadata as one JSON line.

    Raises:
        ValueError: If ``--vocab_size`` does not exceed specials + 256, or
            the source tokenizer lacks one of the byte token ids.
        RuntimeError: If the merge table is exhausted before the requested
            vocab size is reached.
    """
    args = parse_args()
    src_path = Path(args.source_tokenizer)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    with src_path.open("r", encoding="utf-8") as f:
        tokenizer = json.load(f)

    model = tokenizer["model"]
    orig_vocab: dict[str, int] = {str(k): int(v) for k, v in model["vocab"].items()}
    orig_id_to_token = {idx: tok for tok, idx in orig_vocab.items()}
    orig_merges = list(model.get("merges", []))

    # Order-preserving de-duplication of the requested special tokens.
    special_tokens = list(dict.fromkeys(args.special_tokens))

    if args.vocab_size <= len(special_tokens) + 256:
        raise ValueError(
            f"vocab_size={args.vocab_size} must exceed specials+byte alphabet "
            f"({len(special_tokens)}+256)"
        )

    ordered_tokens: list[str] = []
    token_set: set[str] = set()

    def add_token(tok: str) -> None:
        # Insertion order defines the compact id; duplicates are ignored.
        if tok not in token_set:
            token_set.add(tok)
            ordered_tokens.append(tok)

    for tok in special_tokens:
        add_token(tok)
    for tok in base_byte_tokens(orig_id_to_token):
        add_token(tok)

    compact_merges: list[str | list[str]] = []
    for merge in orig_merges:
        left, right = normalize_merge(merge)
        merged = left + right
        # A merge is only usable when both operands are already representable.
        if left not in token_set or right not in token_set:
            continue
        if merged not in orig_vocab:
            continue
        if merged not in token_set:
            if len(ordered_tokens) >= args.vocab_size:
                break
            add_token(merged)
        if merged in token_set:
            # Keep the serialized shape (string vs. pair) of the source entry.
            compact_merges.append(merge_to_json(left, right, merge))
        if len(ordered_tokens) >= args.vocab_size:
            break

    if len(ordered_tokens) != args.vocab_size:
        raise RuntimeError(
            f"Could only build vocab of size {len(ordered_tokens)}; "
            f"requested {args.vocab_size}"
        )

    compact_vocab = {tok: idx for idx, tok in enumerate(ordered_tokens)}

    # Shallow copy: every other top-level section is carried over unchanged.
    compact_tokenizer = dict(tokenizer)
    compact_tokenizer["truncation"] = None
    compact_tokenizer["padding"] = None
    compact_tokenizer["added_tokens"] = [
        {
            "id": compact_vocab[tok],
            "special": True,
            "content": tok,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": True,
        }
        for tok in special_tokens
    ]
    compact_tokenizer["post_processor"] = None
    out_model = dict(model)
    out_model["vocab"] = compact_vocab
    out_model["merges"] = compact_merges
    # Only advertise "<unk>" when it is really in the compact vocab. With a
    # custom --special_tokens list that omits it, a hard-coded "<unk>" would
    # reference a token that has no id in the written tokenizer.
    out_model["unk_token"] = "<unk>" if "<unk>" in compact_vocab else None
    compact_tokenizer["model"] = out_model

    with (out_dir / "tokenizer.json").open("w", encoding="utf-8") as f:
        json.dump(compact_tokenizer, f, ensure_ascii=False)

    meta = {
        "source_tokenizer": str(src_path),
        "vocab_size": args.vocab_size,
        "actual_vocab_size": len(compact_vocab),
        "num_merges": len(compact_merges),
        "special_tokens": special_tokens,
        "construction": "compact_gpt2_bpe_prefix_merges",
    }
    with (out_dir / "tokenizer_meta.json").open("w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
        f.write("\n")

    print(json.dumps(meta, ensure_ascii=False))
157
+
158
+
159
+ if __name__ == "__main__":
160
+ main()
LTA_openwebtext_dualt/scripts/run_lta_lm1b_linear_simplex_len128_4gpu_save1k_with_watch.sh ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
7
+ export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
8
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
9
+ export TOKENIZERS_PARALLELISM=false
10
+ export PYTHONUNBUFFERED=1
11
+ export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
12
+ export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
13
+
14
+ DATA_PATH="${DATA_PATH:-data/lm1b_train_parquet}"
15
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
16
+ SCORER="${SCORER:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"
17
+
18
+ NNODES="${NNODES:-1}"
19
+ NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
20
+ NODE_RANK="${NODE_RANK:-0}"
21
+ MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
22
+ MASTER_PORT="${MASTER_PORT:-32931}"
23
+
24
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
25
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-64}"
26
+ TOTAL_STEPS="${TOTAL_STEPS:-20000}"
27
+ WARMUP_STEPS="${WARMUP_STEPS:-2500}"
28
+ MAX_LEN="${MAX_LEN:-128}"
29
+ NUM_WORKERS="${NUM_WORKERS:-0}"
30
+ LOG_EVERY="${LOG_EVERY:-100}"
31
+ EVAL_EVERY="${EVAL_EVERY:-0}"
32
+ SAVE_EVERY="${SAVE_EVERY:-1000}"
33
+ LATEST_EVERY="${LATEST_EVERY:-1000}"
34
+
35
+ LR="${LR:-3e-4}"
36
+ D_MODEL="${D_MODEL:-768}"
37
+ COND_DIM="${COND_DIM:-128}"
38
+ N_LAYERS="${N_LAYERS:-12}"
39
+ N_HEADS="${N_HEADS:-12}"
40
+ DIM_FF="${DIM_FF:-3072}"
41
+ DROPOUT="${DROPOUT:-0.1}"
42
+
43
+ LINEAR_NOISE_SIGMA="${LINEAR_NOISE_SIGMA:-3.0}"
44
+ WRONG_TOKEN_REPLACE_PROB="${WRONG_TOKEN_REPLACE_PROB:-0.0}"
45
+ WRONG_TOKEN_SCHEDULE="${WRONG_TOKEN_SCHEDULE:-constant}"
46
+ ENDPOINT_TEMP_START="${ENDPOINT_TEMP_START:-1.45}"
47
+ ENDPOINT_TEMP_END="${ENDPOINT_TEMP_END:-0.8}"
48
+ ENDPOINT_PROJECTION="${ENDPOINT_PROJECTION:-soft}"
49
+ FINAL_FROM="${FINAL_FROM:-model_t1}"
50
+
51
+ ENABLE_WATCHER="${ENABLE_WATCHER:-1}"
52
+ WATCH_CUDA_VISIBLE_DEVICES="${WATCH_CUDA_VISIBLE_DEVICES:-0}"
53
+ WATCH_SAMPLES="${WATCH_SAMPLES:-64}"
54
+ WATCH_STEPS="${WATCH_STEPS:-128}"
55
+ WATCH_POLL_SECONDS="${WATCH_POLL_SECONDS:-30}"
56
+
57
+ RESUME_PATH="${RESUME_PATH:-}"
58
+ ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
59
+
60
+ if [[ "${DATA_PATH}" == *"lm1b_train_parquet"* && "${NUM_WORKERS}" != "0" ]]; then
61
+ echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
62
+ NUM_WORKERS=0
63
+ fi
64
+ if [[ ! -e "${DATA_PATH}" ]]; then
65
+ echo "Missing DATA_PATH: ${DATA_PATH}" >&2
66
+ exit 2
67
+ fi
68
+ if [[ ! -f "${TOKENIZER_PATH}" ]]; then
69
+ echo "Missing TOKENIZER_PATH: ${TOKENIZER_PATH}" >&2
70
+ exit 2
71
+ fi
72
+
73
+ RUN_NAME="${RUN_NAME:-lta_lm1b_linear_simplex_len${MAX_LEN}_sigma${LINEAR_NOISE_SIGMA}_gbs${GLOBAL_BATCH_SIZE}_4gpu_20k_save1k_$(date +%Y%m%d_%H%M%S)}"
74
+ SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
75
+ LOG_DIR="${LOG_DIR:-logs/lm1b_linear_simplex_len128_4gpu}"
76
+ LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
77
+ WATCH_LOG="${WATCH_LOG:-${LOG_DIR}/${RUN_NAME}.watch.log}"
78
+
79
+ if [[ -f "${SAVE_DIR}/args.json" && -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
80
+ echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
81
+ exit 2
82
+ fi
83
+
84
+ mkdir -p "${LOG_DIR}" runs "${SAVE_DIR}"
85
+
86
+ RESUME_ARGS=()
87
+ if [[ -n "${RESUME_PATH}" ]]; then
88
+ RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
89
+ fi
90
+
91
+ echo "[launch] method=lm1b_linear_simplex_algebraic_len128_4gpu"
92
+ echo "[launch] host=$(hostname) time=$(date -Iseconds)"
93
+ echo "[launch] run_name=${RUN_NAME}"
94
+ echo "[launch] save_dir=${SAVE_DIR}"
95
+ echo "[launch] bridge=prob linear simplex: p_t=(1-t)*softmax(${LINEAR_NOISE_SIGMA}*N)+t*onehot(y)"
96
+ echo "[launch] wrong_endpoint: prob=${WRONG_TOKEN_REPLACE_PROB} schedule=${WRONG_TOKEN_SCHEDULE}"
97
+ echo "[launch] infer=linear algebraic: p_next=(1-s_next)*p0+s_next*a_theta"
98
+ echo "[launch] watcher=${ENABLE_WATCHER} watch_gpu=${WATCH_CUDA_VISIBLE_DEVICES} watch_samples=${WATCH_SAMPLES}"
99
+
100
# Poll SAVE_DIR for step_<N>.pt checkpoints and run the generative-perplexity
# eval once per checkpoint, in increasing step order, until a checkpoint with
# step >= TOTAL_STEPS has been handled.
#
# Checkpoints are sorted numerically by step before they are processed. A
# plain glob yields lexical order (step_10000.pt before step_9000.pt), and
# because last_done is a high-water mark, a smaller step listed after a larger
# one would be skipped permanently. File names whose step part is not purely
# numeric are ignored instead of aborting the arithmetic.
#
# Eval failures are tolerated on purpose (best effort): the checkpoint is
# still marked as done so the watcher keeps moving.
run_watcher() {
  local last_done=0
  local ckpt base step out_dir
  while true; do
    # fd 3 carries the "step<TAB>path" list so the eval command's stdin is
    # left untouched.
    while IFS=$'\t' read -r -u 3 step ckpt; do
      if (( step <= last_done )); then
        continue
      fi
      out_dir="docs/lta_samples/metrics_$(date +%Y%m%d)/${RUN_NAME}_step${step}_linear_alg_t${ENDPOINT_TEMP_START}_to_${ENDPOINT_TEMP_END}_n${WATCH_SAMPLES}"
      echo "[watch] infer step=${step} ckpt=${ckpt} out=${out_dir}" | tee -a "${WATCH_LOG}"
      CUDA_VISIBLE_DEVICES="${WATCH_CUDA_VISIBLE_DEVICES}" python scripts/eval_lm1b_linear_simplex_genppl.py \
        --checkpoint "${ckpt}" \
        --tokenizer_path "${TOKENIZER_PATH}" \
        --scorer "${SCORER}" \
        --out_dir "${out_dir}" \
        --n_samples "${WATCH_SAMPLES}" \
        --max_len "${MAX_LEN}" \
        --steps "${WATCH_STEPS}" \
        --batch_size 1 \
        --score_batch 1 \
        --noise_init logistic_normal \
        --noise_sigma "${LINEAR_NOISE_SIGMA}" \
        --endpoint_temp_start "${ENDPOINT_TEMP_START}" \
        --endpoint_temp_end "${ENDPOINT_TEMP_END}" \
        --endpoint_projection "${ENDPOINT_PROJECTION}" \
        --endpoint_top_p 1.0 \
        --final_from "${FINAL_FROM}" \
        --seed "$((20260524 + step))" 2>&1 | tee -a "${WATCH_LOG}" || true
      last_done="${step}"
    done 3< <(
      for ckpt in "${SAVE_DIR}"/step_*.pt; do
        [[ -f "${ckpt}" ]] || continue
        base="$(basename "${ckpt}")"
        step="${base#step_}"
        step="${step%.pt}"
        [[ "${step}" =~ ^[0-9]+$ ]] || continue
        # 10# forces base 10 so zero-padded steps are not parsed as octal.
        printf '%d\t%s\n' "$((10#${step}))" "${ckpt}"
      done | sort -n -k1,1
    )
    if (( last_done >= TOTAL_STEPS )); then
      break
    fi
    sleep "${WATCH_POLL_SECONDS}"
  done
}
141
+
142
# Start the optional checkpoint watcher in the background, run distributed
# training in the foreground, then wait for the watcher to finish.
WATCH_PID=""

# Stop the background watcher when this script exits. If training fails,
# `set -e` ends the script before the final `wait`; without this trap the
# watcher would keep polling forever, because its exit condition
# (a checkpoint at TOTAL_STEPS) can no longer be met. An eval process the
# watcher already started may still run to completion.
cleanup_watcher() {
  if [[ -n "${WATCH_PID}" ]] && kill -0 "${WATCH_PID}" 2>/dev/null; then
    kill "${WATCH_PID}" 2>/dev/null || true
  fi
}

if [[ "${ENABLE_WATCHER}" == "1" ]]; then
  run_watcher &
  WATCH_PID=$!
  trap cleanup_watcher EXIT
  echo "[launch] watcher_pid=${WATCH_PID}"
fi

# ${RESUME_ARGS[@]+...} expands to nothing for an empty array; a bare
# "${RESUME_ARGS[@]}" is an unbound-variable error under `set -u` on
# bash releases older than 4.4.
python -m torch.distributed.run \
  --nnodes="${NNODES}" \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank="${NODE_RANK}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py \
  --data_path "${DATA_PATH}" \
  --openwebtext_split all \
  --detokenizer auto \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --save_dir "${SAVE_DIR}" \
  --wrap \
  --wrap_mode stream \
  --wrap_record_buffer_size 200 \
  --max_len "${MAX_LEN}" \
  --batch_size "${PER_GPU_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE}" \
  --total_steps "${TOTAL_STEPS}" \
  --log_every "${LOG_EVERY}" \
  --eval_every "${EVAL_EVERY}" \
  --save_every "${SAVE_EVERY}" \
  --latest_every "${LATEST_EVERY}" \
  --lr "${LR}" \
  --weight_decay 0 \
  --adam_beta1 0.9 \
  --adam_beta2 0.999 \
  --adam_eps 1e-8 \
  --warmup_steps "${WARMUP_STEPS}" \
  --lr_schedule constant_warmup \
  --grad_clip 1.0 \
  --seed 123 \
  --d_model "${D_MODEL}" \
  --cond_dim "${COND_DIM}" \
  --n_layers "${N_LAYERS}" \
  --n_heads "${N_HEADS}" \
  --dim_ff "${DIM_FF}" \
  --dropout "${DROPOUT}" \
  --model_type ddit \
  --state_format prob \
  --bridge prob \
  --target_loss hard_ce \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --t_sampling_mode uniform \
  --dual_t \
  --corrupt_t_mode same \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio 1.0 \
  --max_mask_ratio 1.0 \
  --mask_mixture_all_prob 1.0 \
  --wrong_token_replace_prob "${WRONG_TOKEN_REPLACE_PROB}" \
  --wrong_token_schedule "${WRONG_TOKEN_SCHEDULE}" \
  --eps 1e-8 \
  --infer_steps 128 \
  --decode_damping 1.0 \
  --max_gamma 1.0 \
  --decode_solver flowmap \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma "${LINEAR_NOISE_SIGMA}" \
  ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} \
  --bf16 2>&1 | tee -a "${LOG_FILE}"

if [[ "${ENABLE_WATCHER}" == "1" ]]; then
  # Let the watcher finish evaluating the final checkpoint.
  wait "${WATCH_PID}" || true
fi
LTA_openwebtext_dualt/scripts/run_lta_owt_t5elf_absrope_time4_dirichlet_len1025_C1_to_1024_8gpu_1m_mask1_sameT_prebos_save10k.sh ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ # OWT through the existing ELF/T5 tokenized pipeline:
7
+ # [BOS] + up to 1024 existing T5 tokens.
8
+ #
9
+ # The ELF T5 records already end with </s>, so we only prepend BOS here.
10
+ # T5 has no separate BOS token in this tokenizer; the repo wrapper falls back
11
+ # to bos_id = eos_id = </s>.
12
+ #
13
+ # Backbone:
14
+ # ddit_elf = RMSNorm/SwiGLU/QK-norm + RoPE + 4 prefix time tokens.
15
+ # Learned absolute position embeddings are added before RoPE.
16
+ #
17
+ # Bridge:
18
+ # Dirichlet C=1->1024, mask_ratio=1.0, model t and corruption t are shared.
19
+
20
+ export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
21
+ export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
22
+ export TOKENIZED_HF=1
23
+ export TOKENIZED_PAD_TOKEN="${TOKENIZED_PAD_TOKEN:-pad}"
24
+ export TOKENIZED_PREPEND_BOS=1
25
+ export TOKENIZED_APPEND_EOS=0
26
+ export TOKENIZED_STRIP_EDGE_SPECIALS=0
27
+
28
+ export MAX_LEN="${MAX_LEN:-1025}"
29
+ export VOCAB_SIZE="${VOCAB_SIZE:-32100}"
30
+ export CMIN="${CMIN:-1}"
31
+ export CMAX="${CMAX:-1024}"
32
+
33
+ export MODEL_TYPE=ddit_elf
34
+ export ELF_NUM_TIME_TOKENS="${ELF_NUM_TIME_TOKENS:-4}"
35
+ export ELF_NUM_MODEL_MODE_TOKENS="${ELF_NUM_MODEL_MODE_TOKENS:-0}"
36
+ export QK_NORM="${QK_NORM:-1}"
37
+ export ABS_POS_EMBED=1
38
+ export CORRUPT_T_MODE=same
39
+ export MIN_MASK_RATIO=1.0
40
+ export MAX_MASK_RATIO=1.0
41
+ export CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
42
+
43
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
44
+ export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
45
+ export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
46
+ export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
47
+ export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
48
+ export WARMUP_STEPS="${WARMUP_STEPS:-2500}"
49
+ export SAVE_EVERY="${SAVE_EVERY:-10000}"
50
+ export LATEST_EVERY="${LATEST_EVERY:-1000}"
51
+ export LOG_EVERY="${LOG_EVERY:-100}"
52
+
53
+ export DATE_TAG="${DATE_TAG:-$(date +%Y%m%d)}"
54
+ export RUN_NAME="${RUN_NAME:-lta_owt_t5elf_absrope_time4_dirichlet_len1025_C1_to_1024_prebos_mask1_sameT_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
55
+
56
+ export WATCH_ENABLED="${WATCH_ENABLED:-1}"
57
+ export WATCH_STEP_INTERVAL="${WATCH_STEP_INTERVAL:-10000}"
58
+ export WATCH_N_SAMPLES="${WATCH_N_SAMPLES:-128}"
59
+ export WATCH_CUDA_VISIBLE_DEVICES="${WATCH_CUDA_VISIBLE_DEVICES:-7}"
60
+ export WATCH_DECODE_MODE="${WATCH_DECODE_MODE:-dual_line_probe}"
61
+ export WATCH_DECODE_BATCH="${WATCH_DECODE_BATCH:-1}"
62
+ export WATCH_DUAL_SEMANTIC_POWER="${WATCH_DUAL_SEMANTIC_POWER:-1.5}"
63
+ export WATCH_DUAL_EARLY_TEMP="${WATCH_DUAL_EARLY_TEMP:-2.8}"
64
+ export WATCH_DUAL_LATE_TEMP="${WATCH_DUAL_LATE_TEMP:-1.45}"
65
+ export WATCH_DUAL_TEMP_END="${WATCH_DUAL_TEMP_END:-0.55}"
66
+ export WATCH_DUAL_TEMP_POWER="${WATCH_DUAL_TEMP_POWER:-1.5}"
67
+ export WATCH_ENDPOINT_TEMP="${WATCH_ENDPOINT_TEMP:-1.45}"
68
+ export WATCH_ENDPOINT_TOP_P="${WATCH_ENDPOINT_TOP_P:-0.95}"
69
+ export WATCH_GUMBEL_TAU_START="${WATCH_GUMBEL_TAU_START:-1.0}"
70
+ export WATCH_GUMBEL_TAU_END="${WATCH_GUMBEL_TAU_END:-0.2}"
71
+ export WATCH_OUT_BASE="${WATCH_OUT_BASE:-docs/lta_samples/metrics_${DATE_TAG}/owt_t5elf_absrope_time4_len1025_C1_to_1024_prebos_mask1_sameT_dualline_dirres_c${CMIN}_${CMAX}_n${WATCH_N_SAMPLES}/${RUN_NAME}}"
72
+ export WATCH_LOG_DIR="${WATCH_LOG_DIR:-logs/owt_t5elf_absrope_time4_len1025_C1_to_1024_prebos_mask1_sameT_dualline_watch}"
73
+
74
+ bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
LTA_openwebtext_dualt/scripts/run_train8_ctx1024_rollin_sweep_4gpu.sh ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
5
+
6
+ export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
7
+ export TOKENIZERS_PARALLELISM=false
8
+ export PYTHONUNBUFFERED=1
9
+
10
+ BASE_CACHE="${BASE_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
11
+ TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
12
+ MAX_LEN="${MAX_LEN:-1024}"
13
+ N_SAMPLES="${N_SAMPLES:-64}"
14
+ INFER_STEPS="${INFER_STEPS:-128}"
15
+ STEP_CHUNK="${STEP_CHUNK:-1000}"
16
+ MAX_TOTAL_STEPS="${MAX_TOTAL_STEPS:-12000}"
17
+ PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-128}"
18
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
19
+ LEARNING_RATE="${LEARNING_RATE:-0.002}"
20
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
21
+ GROUP_STAMP="${GROUP_STAMP:-$(date +%Y%m%d_%H%M%S)}"
22
+ WAIT_FOR_RUN="${WAIT_FOR_RUN:-train8_rollin_len1024_rollin_p50_s4_i32_20260517_1855ctx1024bs128}"
23
+ OUT_ROOT="${OUT_ROOT:-docs/lta_samples/metrics_20260517/ctx1024_rollin_sweep_bs512_ode128_${GROUP_STAMP}}"
24
+ DRIVER_LOG="${DRIVER_LOG:-logs/ctx1024_rollin_sweep_4gpu/${GROUP_STAMP}.log}"
25
+ CURVE_CSV="${CURVE_CSV:-${OUT_ROOT}/hit_ratio_curve.csv}"
26
+ mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"
27
+
28
+ cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
29
+ vocab_size="$(
30
+ python - "$cache" <<'PY'
31
+ import json
32
+ import sys
33
+ from pathlib import Path
34
+ meta = json.loads((Path(sys.argv[1]) / "meta.json").read_text())
35
+ print(int(meta.get("compact_vocab_size", meta.get("vocab_size"))))
36
+ PY
37
+ )"
38
+
39
+ if [[ ! -f "${CURVE_CSV}" ]]; then
40
+ echo "config,run_name,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
41
+ fi
42
+
43
# Print the highest <N> among runs/<run_name>/step_<N>.pt, or 0 when the run
# directory holds no such checkpoint.
#   $1: run name (directory under runs/)
latest_step() {
  local run="$1"
  python - "${run}" <<'PY'
import re
import sys
from pathlib import Path

pattern = re.compile(r"step_(\d+)\.pt$")
run_dir = Path("runs") / sys.argv[1]
matches = (pattern.search(ckpt.name) for ckpt in run_dir.glob("step_*.pt"))
print(max((int(m.group(1)) for m in matches if m), default=0))
PY
}
58
+
59
# Print a TCP port number that the OS reports as free on 127.0.0.1.
# The socket is closed before returning, so the port is only free at the
# moment of the check.
free_port() {
  python - <<'PY'
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    # Port 0 asks the kernel to pick any unused port.
    sock.bind(("127.0.0.1", 0))
    print(sock.getsockname()[1])
finally:
    sock.close()
PY
}
67
+
68
# Evaluate the latest checkpoint of one run and append a row to CURVE_CSV.
#   $1: config label (used in the output path and the CSV row)
#   $2: run name; the eval script is pointed at runs/<run name>
#   $3: target step; only used to name the output directory
# Reads the globals OUT_ROOT, cache, TOKENIZER_PATH, MAX_LEN, N_SAMPLES,
# INFER_STEPS, GLOBAL_BATCH_SIZE and CURVE_CSV. The evaluation GPU can be
# chosen with EVAL_CUDA_VISIBLE_DEVICES (default 0).
# NOTE(review): --c_max is 512 here while run_config trains with
# DIRICHLET_CONCENTRATION_MAX=1024 - confirm the mismatch is intended.
eval_latest() {
  local config="$1"
  local run_name="$2"
  local target_step="$3"
  local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
  mkdir -p "${out_dir}"
  CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py \
    --runs_glob "runs/${run_name}" \
    --data_dir "${cache}" \
    --tokenizer_path "${TOKENIZER_PATH}" \
    --out_dir "${out_dir}" \
    --max_len "${MAX_LEN}" \
    --n_samples "${N_SAMPLES}" \
    --batch_size "${N_SAMPLES}" \
    --latest_only \
    --endpoint_softenings none \
    --steps "${INFER_STEPS}" \
    --decode_rule flowmap \
    --time_schedule logit_normal \
    --time_logit_mean -1.5 \
    --time_logit_std 0.8 \
    --model_t_mode post \
    --c_min 1 \
    --c_max 512 \
    --late_temp 1.0 \
    --final_from state \
    --final_decode argmax
  # Summarize the last line of decode_token_acc.jsonl: print a RESULT line and
  # append one CSV row whose columns follow the header written to CURVE_CSV at
  # script start. views = ckpt_step * global batch; tokens = views * max_len.
  python - "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$CURVE_CSV" "$config" "$run_name" <<'PY'
import json
import sys
from pathlib import Path
out = Path(sys.argv[1])
n = int(sys.argv[2])
global_batch = int(sys.argv[3])
max_len = int(sys.argv[4])
curve = Path(sys.argv[5])
config = sys.argv[6]
run_name = sys.argv[7]
row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
views = int(row["ckpt_step"]) * global_batch
tokens = views * max_len
print(
    "RESULT "
    f"config={config} run={run_name} ckpt_step={row['ckpt_step']} "
    f"views={views} token_acc={row['token_acc_mean']:.4f} "
    f"exact={row['exact_count']}/{n} exact_refs={row['exact_ref_count']} "
    f"hits={row['exact_ref_hits']}",
    flush=True,
)
with curve.open("a", encoding="utf-8") as f:
    f.write(
        f"{config},{run_name},{row['ckpt_step']},{views},{tokens},{row['token_acc_mean']},"
        f"{row['exact_count']},{row['exact_ref_count']},\"{row['exact_ref_hits']}\"\n"
    )
PY
}
124
+
125
# Train one sweep configuration in STEP_CHUNK-sized increments, evaluating
# after every chunk, until MAX_TOTAL_STEPS is reached or the evaluation
# reports enough exact reconstructions.
#   $1: config label            $4: rollout inference steps
#   $2: rollout probability     $5: output-layer weight decay
#   $3: rollout steps           $6: ROLLOUT_TRAIN_SYNC_T value (0 or 1)
# The early-stop threshold is STOP_EXACT_COUNT (default 60, the value that
# used to be hard-coded), so it can follow a non-default N_SAMPLES.
run_config() {
  local config="$1"
  local rollout_prob="$2"
  local rollout_steps="$3"
  local rollout_infer_steps="$4"
  local output_wd="$5"
  local sync_t="$6"
  local run_name="${RUN_PREFIX:-train8_ctx1024}_${config}_${GROUP_STAMP}"
  local stop_exact_count="${STOP_EXACT_COUNT:-60}"
  echo "[ctx1024-sweep] config=${config} run=${run_name} p=${rollout_prob} steps=${rollout_steps} infer=${rollout_infer_steps} outwd=${output_wd} sync_t=${sync_t}" | tee -a "${DRIVER_LOG}"
  while :; do
    local step_now
    step_now="$(latest_step "${run_name}")"
    if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
      echo "[ctx1024-sweep] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
      break
    fi
    local target_step=$((step_now + STEP_CHUNK))
    if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
      target_step="${MAX_TOTAL_STEPS}"
    fi
    # Resume from latest.pt when a previous chunk already produced one.
    local resume_path=""
    if [[ -f "runs/${run_name}/latest.pt" ]]; then
      resume_path="runs/${run_name}/latest.pt"
    fi
    echo "[ctx1024-sweep] train config=${config} from=${step_now} to=${target_step}" | tee -a "${DRIVER_LOG}"
    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
    NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
    MASTER_PORT="$(free_port)" \
    OWT_CHUNK_CACHE_DIR="${cache}" \
    OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
    MAX_LEN="${MAX_LEN}" \
    VOCAB_SIZE_OVERRIDE="${vocab_size}" \
    D_MODEL="${D_MODEL:-192}" \
    COND_DIM="${COND_DIM:-64}" \
    N_LAYERS="${N_LAYERS:-3}" \
    N_HEADS="${N_HEADS:-3}" \
    DIM_FF="${DIM_FF:-768}" \
    TOTAL_STEPS="${target_step}" \
    PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
    GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
    NUM_WORKERS="${NUM_WORKERS:-0}" \
    LOG_EVERY="${LOG_EVERY:-100}" \
    SAVE_EVERY="${STEP_CHUNK}" \
    LATEST_EVERY="${STEP_CHUNK}" \
    WARMUP_STEPS="${WARMUP_STEPS:-10}" \
    LEARNING_RATE="${LEARNING_RATE}" \
    WEIGHT_DECAY="${WEIGHT_DECAY}" \
    OUTPUT_WEIGHT_DECAY="${output_wd}" \
    MUON_IMPL="${MUON_IMPL:-legacy}" \
    MIN_MASK_RATIO=1.0 \
    MAX_MASK_RATIO=1.0 \
    MASK_MIXTURE_LOWK_PROB=0.0 \
    MASK_MIXTURE_ALL_PROB=1.0 \
    LOWK_CLEAN_TOKENS=0 \
    CLEAN_STATE_MODE=onehot \
    TARGET_LOSS=hard_ce \
    DIRICHLET_CONCENTRATION_MIN=1.0 \
    DIRICHLET_CONCENTRATION_MAX=1024 \
    SIMPLEX_BRIDGE_SAMPLER=dirichlet \
    CATEGORICAL_WRONG_PROB_FLOOR=0.0 \
    ROLLOUT_TRAIN_PROB="${rollout_prob}" \
    ROLLOUT_TRAIN_STEPS="${rollout_steps}" \
    ROLLOUT_TRAIN_INFER_STEPS="${rollout_infer_steps}" \
    ROLLOUT_TRAIN_TEMP=1.45 \
    ROLLOUT_TRAIN_MAX_GAMMA=1.0 \
    ROLLOUT_TRAIN_CORRUPT_ONLY=1 \
    ROLLOUT_TRAIN_SAMPLEWISE=1 \
    ROLLOUT_TRAIN_SELECTED_ONLY=1 \
    ROLLOUT_TRAIN_COMPUTE_ALWAYS=0 \
    ROLLOUT_TRAIN_SYNC_T="${sync_t}" \
    RUN_NAME="${run_name}" \
    RESUME_PATH="${resume_path}" \
    bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
    echo "[ctx1024-sweep] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"
    # Column 7 of the curve CSV is exact_count for the row just appended.
    if tail -n 1 "${CURVE_CSV}" | awk -F, -v thr="${stop_exact_count}" '{exit ($7 >= thr) ? 0 : 1}'; then
      echo "[ctx1024-sweep] early-hit config=${config}" | tee -a "${DRIVER_LOG}"
      break
    fi
  done
}
206
+
207
+ if [[ -n "${WAIT_FOR_RUN}" ]]; then
208
+ echo "[ctx1024-sweep] waiting for run=${WAIT_FOR_RUN}" | tee -a "${DRIVER_LOG}"
209
+ while pgrep -f "${WAIT_FOR_RUN}" >/dev/null; do
210
+ sleep 60
211
+ done
212
+ fi
213
+
214
+ echo "[ctx1024-sweep] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"
215
+
216
+ CONFIGS=(
217
+ "p75_s4_i32_outwdm1|0.75|4|32|-1|0"
218
+ "p100_s4_i32_outwdm1|1.0|4|32|-1|0"
219
+ "p50_s8_i64_outwdm1|0.50|8|64|-1|0"
220
+ "p50_s4_i32_sync1_outwdm1|0.50|4|32|-1|1"
221
+ "p50_s8_i64_sync1_outwdm1|0.50|8|64|-1|1"
222
+ "p50_s4_i32_outwd0p3|0.50|4|32|0.3|0"
223
+ "p50_s4_i32_outwd1p0|0.50|4|32|1.0|0"
224
+ )
225
+
226
+ if [[ -n "${SWEEP_CONFIGS:-}" ]]; then
227
+ IFS=$'\n' read -r -d '' -a CONFIGS < <(printf '%s\0' "${SWEEP_CONFIGS}") || true
228
+ fi
229
+
230
+ for entry in "${CONFIGS[@]}"; do
231
+ IFS='|' read -r config rollout_prob rollout_steps rollout_infer_steps output_wd sync_t <<<"${entry}"
232
+ run_config "${config}" "${rollout_prob}" "${rollout_steps}" "${rollout_infer_steps}" "${output_wd}" "${sync_t}"
233
+ done
234
+
235
+ echo "[ctx1024-sweep] done" | tee -a "${DRIVER_LOG}"
LTA_openwebtext_dualt/scripts/run_train8_len256_headwd_sweep_bs512_until_exact_4gpu.sh ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Driver for the train8 head-weight-decay sweep: settings below can be
# overridden from the environment; unset or empty values fall back to the
# defaults given here.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

# Prepend the project root to PYTHONPATH, keeping any existing entries.
export PYTHONPATH="${PWD}${PYTHONPATH:+:${PYTHONPATH}}"
export TOKENIZERS_PARALLELISM=false PYTHONUNBUFFERED=1

# Data and tokenizer locations.
: "${BASE_CACHE:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"

# Evaluation settings and stop thresholds.
: "${MAX_LEN:=256}"
: "${N_SAMPLES:=64}"
: "${STOP_EXACT_COUNT:=1}"
: "${STOP_EXACT_REF_COUNT:=0}"
: "${INFER_STEPS:=128}"

# Training schedule and optimizer settings.
: "${STEP_CHUNK:=500}"
: "${MAX_TOTAL_STEPS:=100000}"
: "${MASTER_PORT_BASE:=38200}"
: "${PER_GPU_BATCH_SIZE:=128}"
: "${GLOBAL_BATCH_SIZE:=512}"
: "${OWT_EXACT_REPEAT_PER_CHUNK:=64}"
: "${LEARNING_RATE:=0.002}"
: "${WEIGHT_DECAY:=0.1}"
: "${HEAD_WD_VALUES:=0.1 0.3 1.0 3.0}"

# Output locations; the summary and curve paths default to live under OUT_ROOT.
: "${OUT_ROOT:=docs/lta_samples/metrics_20260517/train8_len256_headwd_sweep_bs512_until_exact_ode128}"
: "${DRIVER_LOG:=logs/train8_len256_headwd_sweep_bs512_until_exact_4gpu/driver.log}"
: "${SUMMARY_JSONL:=${OUT_ROOT}/summary_results.jsonl}"
: "${CURVE_CSV:=${OUT_ROOT}/hit_ratio_curve.csv}"
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"
31
+
32
# Encode a numeric string so it is safe inside run and directory names:
# every '-' becomes 'm' and every '.' becomes 'p' (for example -1.5 -> m1p5).
# The result is printed without a trailing newline.
tag_float() {
  local encoded="${1//-/m}"
  printf "%s" "${encoded//./p}"
}
35
+
36
# Print the vocabulary size stored in meta.json of the compact train8 cache
# for the current MAX_LEN (read from BASE_CACHE). The "compact_vocab_size"
# entry wins when present; otherwise "vocab_size" is used.
vocab_size_for_len() {
  python - "$BASE_CACHE" "$MAX_LEN" <<'PY'
import json
import sys
from pathlib import Path

cache_root, seq_len = Path(sys.argv[1]), int(sys.argv[2])
meta_path = cache_root / f"gpt2_len{seq_len}_train8_compact_overfit" / "meta.json"
with meta_path.open() as fh:
    meta = json.load(fh)
if "compact_vocab_size" in meta:
    size = meta["compact_vocab_size"]
else:
    size = meta.get("vocab_size")
print(int(size))
PY
}
47
+
48
# Print the largest N for which runs/<run_name>/step_N.pt exists, or 0 when
# the run has no such checkpoint.
#   $1  run name (directory under runs/)
latest_step() {
  local run_name="$1"
  python - "$run_name" <<'PY'
import re
import sys
from pathlib import Path

STEP_RE = re.compile(r"step_(\d+)\.pt$")
run_dir = Path("runs") / sys.argv[1]
hits = (STEP_RE.search(ckpt.name) for ckpt in run_dir.glob("step_*.pt"))
print(max((int(hit.group(1)) for hit in hits if hit), default=0))
PY
}
63
+
64
# Rebuild CURVE_CSV from scratch out of SUMMARY_JSONL.
# Rows whose "run" name does not follow the head-weight-decay sweep naming
# scheme are ignored, and only the first row for each
# (context_len, head_weight_decay, mode, step) combination is kept.
# Prints "curve_rows=<count>" when finished.
refresh_curve() {
  python - "$SUMMARY_JSONL" "$CURVE_CSV" <<'PY'
import csv
import json
import re
import sys
from pathlib import Path

RUN_RE = re.compile(
    r"train8_n(\d+)_compactv\d+_small_headwd([^_]+)_bs512_hard_ce_(onehot|allcorrupt)"
)
FIELDNAMES = [
    "context_len",
    "head_weight_decay",
    "mode",
    "step",
    "train_views_seen",
    "train_tokens_seen",
    "hit_ratio",
    "exact_count",
    "ref_coverage",
    "exact_ref_count",
    "token_acc_mean",
]

summary_path, csv_path = Path(sys.argv[1]), Path(sys.argv[2])
raw_lines = summary_path.read_text().splitlines() if summary_path.exists() else []

# Insertion-ordered mapping: the first record seen for a key wins.
by_key = {}
for raw in raw_lines:
    if not raw.strip():
        continue
    rec = json.loads(raw)
    hit = RUN_RE.search(rec.get("run", ""))
    if hit is None:
        continue
    ctx_len = int(hit.group(1))
    # Undo the name-safe float encoding ('p' for '.', 'm' for '-').
    wd = hit.group(2).replace("p", ".").replace("m", "-")
    mode = hit.group(3)
    step = int(rec["ckpt_step"])
    key = (ctx_len, wd, mode, step)
    if key in by_key:
        continue
    generated = float(rec.get("n_gen", 64))
    by_key[key] = {
        "context_len": ctx_len,
        "head_weight_decay": wd,
        "mode": mode,
        "step": step,
        "train_views_seen": int(rec.get("train_views_seen", step * 512)),
        "train_tokens_seen": int(rec.get("train_tokens_seen", step * 512 * ctx_len)),
        "hit_ratio": float(rec.get("exact_count", 0)) / generated,
        "exact_count": int(rec.get("exact_count", 0)),
        "ref_coverage": float(rec.get("exact_ref_count", 0)) / 8.0,
        "exact_ref_count": int(rec.get("exact_ref_count", 0)),
        "token_acc_mean": float(rec.get("token_acc_mean", 0.0)),
    }


def sort_key(entry):
    return (
        entry["context_len"],
        float(entry["head_weight_decay"]),
        entry["mode"],
        entry["step"],
    )


ordered = sorted(by_key.values(), key=sort_key)
csv_path.parent.mkdir(parents=True, exist_ok=True)
with csv_path.open("w", newline="", encoding="utf-8") as fh:
    table = csv.DictWriter(fh, fieldnames=FIELDNAMES)
    table.writeheader()
    for entry in ordered:
        table.writerow(entry)
print(f"curve_rows={len(ordered)}")
PY
}
131
+
132
# Evaluate the latest checkpoint of a run, then record and report the result.
#   $1  run name (directory under runs/)
#   $2  target training step, used only to name the output directory
#   $3  head weight decay of the run
#   $4  data cache directory passed to the eval script
# Appends the last record of decode_token_acc.jsonl (extended with
# head_weight_decay, train_views_seen and train_tokens_seen) to SUMMARY_JSONL
# and prints a RESULT line.
# Returns 0 when exact_count >= STOP_EXACT_COUNT and
# exact_ref_count >= STOP_EXACT_REF_COUNT, otherwise 1.
eval_latest() {
  local run_name="$1"
  local target_step="$2"
  local head_wd="$3"
  local cache="$4"
  local out_dir="${OUT_ROOT}/len${MAX_LEN}/headwd_$(tag_float "${head_wd}")/${run_name}/step_${target_step}"
  mkdir -p "${out_dir}"

  local -a decode_args=(
    --runs_glob "runs/${run_name}"
    --data_dir "${cache}"
    --tokenizer_path "${TOKENIZER_PATH}"
    --out_dir "${out_dir}"
    --max_len "${MAX_LEN}"
    --n_samples "${N_SAMPLES}"
    --batch_size "${N_SAMPLES}"
    --latest_only
    --endpoint_softenings none
    --steps "${INFER_STEPS}"
    --decode_rule flowmap
    --time_schedule logit_normal
    --time_logit_mean -1.5
    --time_logit_std 0.8
    --model_t_mode post
    --c_min 1
    --c_max 512
    --late_temp 1.0
    --final_from state
    --final_decode argmax
  )
  CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py "${decode_args[@]}"

  python - "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$SUMMARY_JSONL" "$STOP_EXACT_COUNT" "$STOP_EXACT_REF_COUNT" "$head_wd" <<'PY'
import json
import sys
from pathlib import Path

out_dir, summary_path = Path(sys.argv[1]), Path(sys.argv[5])
n_samples, batch, seq_len = (int(arg) for arg in sys.argv[2:5])
need_exact, need_refs = int(sys.argv[6]), int(sys.argv[7])
wd = float(sys.argv[8])

# Only the final record written by the eval script is used.
with (out_dir / "decode_token_acc.jsonl").open() as fh:
    record = json.loads(fh.read().splitlines()[-1])

views = int(record["ckpt_step"]) * batch
record["head_weight_decay"] = wd
record["train_views_seen"] = views
record["train_tokens_seen"] = views * seq_len

parts = [
    "RESULT",
    f"len={seq_len}",
    f"head_wd={wd:g}",
    f"run={record['run']}",
    f"ckpt_step={record['ckpt_step']}",
    f"train_views_seen={record['train_views_seen']}",
    f"train_tokens_seen={record['train_tokens_seen']}",
    f"token_acc_mean={record['token_acc_mean']:.4f}",
    f"exact_count={record['exact_count']}/{n_samples}",
    f"exact_ref_count={record['exact_ref_count']}",
    f"exact_refs={record['exact_ref_hits']}",
]
print(" ".join(parts), flush=True)

summary_path.parent.mkdir(parents=True, exist_ok=True)
with summary_path.open("a", encoding="utf-8") as fh:
    fh.write(json.dumps(record, ensure_ascii=False) + "\n")

# The exit status tells the caller whether both stop thresholds were met.
reached = (
    int(record["exact_count"]) >= need_exact
    and int(record["exact_ref_count"]) >= need_refs
)
sys.exit(0 if reached else 1)
PY
}
202
+
203
# Compact train8 cache for the configured MAX_LEN and its vocabulary size.
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(vocab_size_for_len)"

# Corruption modes as '|'-separated records:
#   mode|MASK_MIXTURE_LOWK_PROB|MASK_MIXTURE_ALL_PROB|LOWK_CLEAN_TOKENS
mode_specs=(
  "onehot|1|0|1,2,4"
  "allcorrupt|0|1|0"
)

# Run the shared 4-GPU launcher until TOTAL_STEPS=target_step for the run
# described by the loop variables currently in scope (head_wd, run_name,
# lowk_*, all_prob, target_step, resume_path). The rendezvous port is offset
# by the chunk index so successive chunks use different ports.
train_chunk() {
  CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
  NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
  MASTER_PORT=$((MASTER_PORT_BASE + target_step / STEP_CHUNK)) \
  OWT_CHUNK_CACHE_DIR="${cache}" \
  OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK}" \
  MAX_LEN="${MAX_LEN}" \
  VOCAB_SIZE_OVERRIDE="${vocab_size}" \
  D_MODEL="${D_MODEL:-192}" \
  COND_DIM="${COND_DIM:-64}" \
  N_LAYERS="${N_LAYERS:-3}" \
  N_HEADS="${N_HEADS:-3}" \
  DIM_FF="${DIM_FF:-768}" \
  LOWK_CLEAN_TOKENS="${lowk_tokens}" \
  MASK_MIXTURE_LOWK_PROB="${lowk_prob}" \
  MASK_MIXTURE_ALL_PROB="${all_prob}" \
  TOTAL_STEPS="${target_step}" \
  PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
  GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
  NUM_WORKERS="${NUM_WORKERS:-0}" \
  LOG_EVERY="${LOG_EVERY:-50}" \
  SAVE_EVERY="${STEP_CHUNK}" \
  LATEST_EVERY="${STEP_CHUNK}" \
  WARMUP_STEPS="${WARMUP_STEPS:-10}" \
  LEARNING_RATE="${LEARNING_RATE}" \
  WEIGHT_DECAY="${WEIGHT_DECAY}" \
  OUTPUT_WEIGHT_DECAY="${head_wd}" \
  MUON_IMPL="${MUON_IMPL:-optax}" \
  TARGET_LOSS="hard_ce" \
  CLEAN_STATE_MODE="onehot" \
  RUN_NAME="${run_name}" \
  RESUME_PATH="${resume_path}" \
  bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
}

# For every head weight decay and corruption mode: train in STEP_CHUNK-sized
# increments and evaluate after each one, until eval_latest reports success,
# a DONE flag is found, or MAX_TOTAL_STEPS is reached.
for head_wd in ${HEAD_WD_VALUES}; do
  head_tag="$(tag_float "${head_wd}")"
  echo "[headwd-sweep] head_wd=${head_wd} $(date)" | tee -a "${DRIVER_LOG}"
  for mode_spec in "${mode_specs[@]}"; do
    IFS='|' read -r mode lowk_prob all_prob lowk_tokens <<<"${mode_spec}"
    run_name="train8_n${MAX_LEN}_compactv${vocab_size}_small_headwd${head_tag}_bs512_hard_ce_${mode}"
    done_flag="${OUT_ROOT}/len${MAX_LEN}/headwd_${head_tag}/${run_name}/DONE"
    while true; do
      if [[ -f "${done_flag}" ]]; then
        echo "[headwd-sweep] skip done head_wd=${head_wd} mode=${mode} run=${run_name}" | tee -a "${DRIVER_LOG}"
        break
      fi

      step_now="$(latest_step "${run_name}")"
      if (( step_now >= MAX_TOTAL_STEPS )); then
        echo "[headwd-sweep] hit max_total_steps head_wd=${head_wd} mode=${mode} step=${step_now}" | tee -a "${DRIVER_LOG}"
        break
      fi

      # Next stop: one chunk further, clamped to the overall step budget.
      target_step=$((step_now + STEP_CHUNK))
      if (( target_step > MAX_TOTAL_STEPS )); then
        target_step="${MAX_TOTAL_STEPS}"
      fi

      # Resume from latest.pt when the run already has one.
      resume_path=""
      if [[ -f "runs/${run_name}/latest.pt" ]]; then
        resume_path="runs/${run_name}/latest.pt"
      fi

      echo "[headwd-sweep] train head_wd=${head_wd} mode=${mode} run=${run_name} from=${step_now} to=${target_step}" | tee -a "${DRIVER_LOG}"
      train_chunk

      echo "[headwd-sweep] eval head_wd=${head_wd} mode=${mode} run=${run_name} step=${target_step}" | tee -a "${DRIVER_LOG}"
      if eval_latest "${run_name}" "${target_step}" "${head_wd}" "${cache}" | tee -a "${DRIVER_LOG}"; then
        echo "[headwd-sweep] DONE head_wd=${head_wd} mode=${mode} run=${run_name} step=${target_step}" | tee -a "${DRIVER_LOG}"
        touch "${done_flag}"
        refresh_curve | tee -a "${DRIVER_LOG}"
        break
      else
        echo "[headwd-sweep] continue head_wd=${head_wd} mode=${mode} run=${run_name} step=${target_step}" | tee -a "${DRIVER_LOG}"
        refresh_curve | tee -a "${DRIVER_LOG}"
      fi
    done
  done
done

refresh_curve | tee -a "${DRIVER_LOG}"
LTA_openwebtext_dualt/scripts/run_train8_noise_geometry_pilots_4gpu.sh ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Driver for the train8 noise-geometry pilots: settings below can be
# overridden from the environment; unset or empty values fall back to the
# defaults given here.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

# Prepend the project root to PYTHONPATH, keeping any existing entries.
export PYTHONPATH="${PWD}${PYTHONPATH:+:${PYTHONPATH}}"
export TOKENIZERS_PARALLELISM=false PYTHONUNBUFFERED=1

# Data and tokenizer locations.
: "${BASE_CACHE:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"

# Evaluation settings, training schedule and stop thresholds.
: "${MAX_LEN:=256}"
: "${N_SAMPLES:=64}"
: "${INFER_STEPS:=128}"
: "${STEP_CHUNK:=1000}"
: "${MAX_TOTAL_STEPS:=30000}"
: "${STOP_EXACT_COUNT:=64}"
: "${STOP_EXACT_REF_COUNT:=8}"
: "${MASTER_PORT_BASE:=32950}"
: "${PER_GPU_BATCH_SIZE:=128}"
: "${GLOBAL_BATCH_SIZE:=512}"

# Output locations; GROUP_STAMP defaults to the current timestamp and is
# embedded in the default output directory and log names.
: "${GROUP_STAMP:=$(date +%Y%m%d_%H%M%S)}"
: "${OUT_ROOT:=docs/lta_samples/metrics_20260517/noise_geometry_pilots_len${MAX_LEN}_bs512_ode128_${GROUP_STAMP}}"
: "${DRIVER_LOG:=logs/noise_geometry_pilots_4gpu/${GROUP_STAMP}.log}"
: "${SUMMARY_JSONL:=${OUT_ROOT}/summary_results.jsonl}"
: "${CURVE_CSV:=${OUT_ROOT}/hit_ratio_curve.csv}"
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"

# Compact train8 cache for MAX_LEN; its meta.json supplies the vocabulary
# size ("compact_vocab_size" when present, otherwise "vocab_size").
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(
python - "$cache" <<'PY'
import json
import sys
from pathlib import Path

with (Path(sys.argv[1]) / "meta.json").open() as fh:
    meta = json.load(fh)
if "compact_vocab_size" in meta:
    size = meta["compact_vocab_size"]
else:
    size = meta.get("vocab_size")
print(int(size))
PY
)"

# Start the curve CSV with its header unless the file already exists.
[[ -f "${CURVE_CSV}" ]] || echo "config,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
43
+
44
# Print the largest N for which runs/<run_name>/step_N.pt exists, or 0 when
# the run has no such checkpoint.
#   $1  run name (directory under runs/)
latest_step() {
  local run_name="$1"
  python - "$run_name" <<'PY'
import re
import sys
from pathlib import Path

STEP_RE = re.compile(r"step_(\d+)\.pt$")
run_dir = Path("runs") / sys.argv[1]
hits = (STEP_RE.search(ckpt.name) for ckpt in run_dir.glob("step_*.pt"))
print(max((int(hit.group(1)) for hit in hits if hit), default=0))
PY
}
59
+
60
# Print a TCP port number that the OS reports as free on 127.0.0.1.
# The probing socket is closed before the function returns, so the port is
# only known to have been free at the moment of the check.
free_port() {
  python - <<'PY'
import socket

probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    # Port 0 asks the OS to pick any unused port.
    probe.bind(("127.0.0.1", 0))
    print(probe.getsockname()[1])
finally:
    probe.close()
PY
}
68
+
69
# Evaluate the latest checkpoint of a run, then record and report the result.
#   $1  config name (labels the output directory, summary row and CSV row)
#   $2  run name (directory under runs/)
#   $3  target training step, used only to name the output directory
# Appends the last record of decode_token_acc.jsonl (extended with config,
# train_views_seen and train_tokens_seen) to SUMMARY_JSONL, appends one row
# to CURVE_CSV and prints a RESULT line.
# Returns 0 when exact_count >= STOP_EXACT_COUNT and
# exact_ref_count >= STOP_EXACT_REF_COUNT, otherwise 1.
eval_latest() {
  local config="$1"
  local run_name="$2"
  local target_step="$3"
  local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
  mkdir -p "${out_dir}"

  local -a decode_args=(
    --runs_glob "runs/${run_name}"
    --data_dir "${cache}"
    --tokenizer_path "${TOKENIZER_PATH}"
    --out_dir "${out_dir}"
    --max_len "${MAX_LEN}"
    --n_samples "${N_SAMPLES}"
    --batch_size "${N_SAMPLES}"
    --latest_only
    --endpoint_softenings none
    --steps "${INFER_STEPS}"
    --decode_rule flowmap
    --time_schedule logit_normal
    --time_logit_mean -1.5
    --time_logit_std 0.8
    --model_t_mode post
    --c_min 1
    --c_max 512
    --late_temp 1.0
    --final_from state
    --final_decode argmax
  )
  CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py "${decode_args[@]}"

  python - "$config" "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$SUMMARY_JSONL" "$CURVE_CSV" "$STOP_EXACT_COUNT" "$STOP_EXACT_REF_COUNT" <<'PY'
import json
import sys
from pathlib import Path

config = sys.argv[1]
out_dir = Path(sys.argv[2])
n_samples, batch, seq_len = (int(arg) for arg in sys.argv[3:6])
summary_path, curve_path = Path(sys.argv[6]), Path(sys.argv[7])
need_exact, need_refs = int(sys.argv[8]), int(sys.argv[9])

# Only the final record written by the eval script is used.
with (out_dir / "decode_token_acc.jsonl").open() as fh:
    record = json.loads(fh.read().splitlines()[-1])

views = int(record["ckpt_step"]) * batch
record["config"] = config
record["train_views_seen"] = views
record["train_tokens_seen"] = views * seq_len

parts = [
    "RESULT",
    f"config={config}",
    f"ckpt_step={record['ckpt_step']}",
    f"views={record['train_views_seen']}",
    f"token_acc={record['token_acc_mean']:.4f}",
    f"exact={record['exact_count']}/{n_samples}",
    f"exact_refs={record['exact_ref_count']}",
    f"hits={record['exact_ref_hits']}",
]
print(" ".join(parts), flush=True)

summary_path.parent.mkdir(parents=True, exist_ok=True)
with summary_path.open("a", encoding="utf-8") as fh:
    fh.write(json.dumps(record, ensure_ascii=False) + "\n")

# One CSV row per evaluation; the hits field is always wrapped in double
# quotes, the other cells are written as-is.
hits = record["exact_ref_hits"]
cells = [
    config,
    record["ckpt_step"],
    record["train_views_seen"],
    record["train_tokens_seen"],
    record["token_acc_mean"],
    record["exact_count"],
    record["exact_ref_count"],
]
csv_row = ",".join(format(cell) for cell in cells) + f',"{hits}"\n'
with curve_path.open("a", encoding="utf-8") as fh:
    fh.write(csv_row)

# The exit status tells the caller whether both stop thresholds were met.
reached = (
    int(record["exact_count"]) >= need_exact
    and int(record["exact_ref_count"]) >= need_refs
)
sys.exit(0 if reached else 1)
PY
}
142
+
143
# Export the baseline value of every noise-related variable so that settings
# from a previously configured pilot cannot leak into the next one.
reset_noise_defaults() {
  # MASK_* / LOWK_* settings.
  export MIN_MASK_RATIO=1.0 MAX_MASK_RATIO=1.0
  export MASK_MIXTURE_LOWK_PROB=0.0 MASK_MIXTURE_ALL_PROB=1.0
  export LOWK_CLEAN_TOKENS=0

  # Clean-state mode, loss and LINEAR_SOFT_TARGET_* settings.
  export CLEAN_STATE_MODE=onehot TARGET_LOSS=hard_ce
  export LINEAR_SOFT_TARGET_POWER=1.0
  export LINEAR_SOFT_TARGET_MIN_CONF=0.0 LINEAR_SOFT_TARGET_MAX_CONF=1.0

  # Simplex bridge sampler, Dirichlet concentration range and sequence alpha.
  export DIRICHLET_CONCENTRATION_MIN=1.0 DIRICHLET_CONCENTRATION_MAX=1024
  export SIMPLEX_BRIDGE_SAMPLER=dirichlet
  export ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA=0.0

  # CATEGORICAL_WRONG_* settings.
  export CATEGORICAL_WRONG_BASIN_TOKEN_IDS=""
  export CATEGORICAL_WRONG_BASIN_PROB=0.0
  export CATEGORICAL_WRONG_UNIGRAM_PROB=0.0
  export CATEGORICAL_WRONG_UNIFORM_PROB=0.0
  export CATEGORICAL_WRONG_BASIN_SHARED_PROB=0.0
  export CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB=0.0

  # LOGISTIC_NORMAL_* sigma and tau ranges.
  export LOGISTIC_NORMAL_SIGMA_MIN=0.1 LOGISTIC_NORMAL_SIGMA_MAX=1.0
  export LOGISTIC_NORMAL_TAU_MIN=1.0 LOGISTIC_NORMAL_TAU_MAX=1.0
}
169
+
170
# Export the noise settings for one named pilot configuration.
#   $1  config name
# Every call first restores the baseline via reset_noise_defaults and then
# applies only that config's overrides. An unknown name is reported on
# stderr and the function returns 2.
configure_noise() {
  local config="$1"
  reset_noise_defaults
  case "${config}" in
    allcorrupt_fullvocab_dirC1_1024)
      # Baseline settings, no overrides.
      ;;
    allcorrupt_seqrand1_dirC1_1024)
      export ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA=1.0
      ;;
    allcorrupt_highC64_4096)
      export DIRICHLET_CONCENTRATION_MIN=64 DIRICHLET_CONCENTRATION_MAX=4096
      ;;
    allcorrupt_logistic_sig0p05_0p5)
      export SIMPLEX_BRIDGE_SAMPLER=logistic_normal_linear_mean
      export LOGISTIC_NORMAL_SIGMA_MIN=0.05 LOGISTIC_NORMAL_SIGMA_MAX=0.5
      ;;
    allcorrupt_seqrand1_highC64_4096)
      export ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA=1.0
      export DIRICHLET_CONCENTRATION_MIN=64 DIRICHLET_CONCENTRATION_MAX=4096
      ;;
    allcorrupt_unigram_shared0p5)
      export CATEGORICAL_WRONG_UNIGRAM_PROB=1.0
      export CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB=0.5
      ;;
    *)
      echo "unknown config: ${config}" >&2
      return 2
      ;;
  esac
}
203
+
204
# Pilot configurations, trained round-robin; each name must be handled by
# configure_noise.
configs=(
  allcorrupt_fullvocab_dirC1_1024
  allcorrupt_seqrand1_dirC1_1024
  allcorrupt_highC64_4096
  allcorrupt_logistic_sig0p05_0p5
  allcorrupt_seqrand1_highC64_4096
  allcorrupt_unigram_shared0p5
)

echo "[noise-pilot] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"

# Each round advances every unfinished config by one STEP_CHUNK and
# evaluates it. The loop ends after a round in which every config was
# either flagged DONE or had already reached MAX_TOTAL_STEPS.
round_idx=0
while :; do
  round_idx=$((round_idx + 1))
  active=0
  echo "[noise-pilot] round=${round_idx} $(date)" | tee -a "${DRIVER_LOG}"

  for config in "${configs[@]}"; do
    run_name="train8_noisegeo_len${MAX_LEN}_${config}_${GROUP_STAMP}"
    done_flag="${OUT_ROOT}/${config}/DONE"
    if [[ -f "${done_flag}" ]]; then
      echo "[noise-pilot] skip done config=${config}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    step_now="$(latest_step "${run_name}")"
    if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
      echo "[noise-pilot] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    # At least one config still has work left, so another round is needed.
    active=1

    # Next stop: one chunk further, clamped to the overall step budget.
    target_step=$((step_now + STEP_CHUNK))
    if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
      target_step="${MAX_TOTAL_STEPS}"
    fi

    # Resume from latest.pt when the run already has one.
    resume_path=""
    if [[ -f "runs/${run_name}/latest.pt" ]]; then
      resume_path="runs/${run_name}/latest.pt"
    fi

    # Exports this config's noise settings, which the launcher below inherits.
    configure_noise "${config}"
    echo "[noise-pilot] train config=${config} from=${step_now} to=${target_step} C=${DIRICHLET_CONCENTRATION_MIN}->${DIRICHLET_CONCENTRATION_MAX} seq_alpha=${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA} sampler=${SIMPLEX_BRIDGE_SAMPLER}" | tee -a "${DRIVER_LOG}"

    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
    NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
    MASTER_PORT="$(free_port)" \
    OWT_CHUNK_CACHE_DIR="${cache}" \
    OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
    MAX_LEN="${MAX_LEN}" \
    VOCAB_SIZE_OVERRIDE="${vocab_size}" \
    D_MODEL="${D_MODEL:-192}" \
    COND_DIM="${COND_DIM:-64}" \
    N_LAYERS="${N_LAYERS:-3}" \
    N_HEADS="${N_HEADS:-3}" \
    DIM_FF="${DIM_FF:-768}" \
    TOTAL_STEPS="${target_step}" \
    PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
    GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
    NUM_WORKERS="${NUM_WORKERS:-0}" \
    LOG_EVERY="${LOG_EVERY:-100}" \
    SAVE_EVERY="${STEP_CHUNK}" \
    LATEST_EVERY="${STEP_CHUNK}" \
    WARMUP_STEPS="${WARMUP_STEPS:-10}" \
    LEARNING_RATE="${LEARNING_RATE:-0.002}" \
    WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" \
    MUON_IMPL="${MUON_IMPL:-legacy}" \
    OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}" \
    RUN_NAME="${run_name}" \
    RESUME_PATH="${resume_path}" \
    bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh

    echo "[noise-pilot] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    # With pipefail set, the pipeline status reflects eval_latest's exit code.
    if eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"; then
      echo "[noise-pilot] DONE config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
      touch "${done_flag}"
    else
      echo "[noise-pilot] continue config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    fi
  done

  if [[ "${active}" -eq 0 ]]; then
    echo "[noise-pilot] all configs done/capped $(date)" | tee -a "${DRIVER_LOG}"
    break
  fi
done
LTA_openwebtext_dualt/scripts/sweep_categorical_c1024_diffusion_finalsample_20260506.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import importlib.util
5
+ import sys
6
+ from pathlib import Path
7
+
8
+
9
# Load the sibling rolling-noise sweep script as a module named after its
# file, so that its DecodeCfg, default_configs and main can be reused here.
_BASE_MODULE_NAME = "sweep_categorical_c1024_rolling_noise_20260506"
BASE_PATH = Path(__file__).parent / f"{_BASE_MODULE_NAME}.py"
spec = importlib.util.spec_from_file_location(_BASE_MODULE_NAME, BASE_PATH)
if spec is not None and spec.loader is not None:
    sweep = importlib.util.module_from_spec(spec)
    # Register the module before executing it so it is visible in
    # sys.modules while its own top-level code runs.
    sys.modules[spec.name] = sweep
    spec.loader.exec_module(sweep)
else:
    raise RuntimeError(f"Could not load {BASE_PATH}")
16
+
17
+
18
def finalsample_configs():
    """Return the decode configurations for the final-sample sweep.

    Every configuration is a ``sweep.DecodeCfg`` with ``final_sample=True``.
    The list holds ten configs that leave ``final_from`` at its default,
    then two with ``final_from="state"``, then one that also sets
    ``support_power`` and ``semantic_power``.
    """
    make_cfg = sweep.DecodeCfg

    def sampled(name, endpoint_temp, final_temp, **extra):
        # All sweep entries share final_sample=True; only the rest varies.
        return make_cfg(
            name,
            endpoint_temp=endpoint_temp,
            final_sample=True,
            final_temp=final_temp,
            **extra,
        )

    # (name, endpoint_temp, final_temp)
    blend_grid = [
        ("fs_t1p3_eta1_blend_ft0p35", 1.3, 0.35),
        ("fs_t1p3_eta1_blend_ft0p50", 1.3, 0.50),
        ("fs_t1p3_eta1_blend_ft0p70", 1.3, 0.70),
        ("fs_t1p3_eta1_blend_ft1p00", 1.3, 1.00),
        ("fs_t1p3_eta1_blend_ft1p30", 1.3, 1.30),
        ("fs_t1p3_eta1_blend_ft1p60", 1.3, 1.60),
        ("fs_t1p2_eta1_blend_ft0p50", 1.2, 0.50),
        ("fs_t1p2_eta1_blend_ft0p70", 1.2, 0.70),
        ("fs_t1p2_eta1_blend_ft1p00", 1.2, 1.00),
        ("fs_t1p1_eta1_blend_ft0p70", 1.1, 0.70),
    ]
    state_grid = [
        ("fs_t1p3_eta1_state_ft0p35", 1.3, 0.35),
        ("fs_t1p3_eta1_state_ft0p50", 1.3, 0.50),
    ]

    configs = [sampled(*row) for row in blend_grid]
    configs.extend(sampled(*row, final_from="state") for row in state_grid)
    configs.append(
        sampled(
            "fs_sp1p15_sem1p7_t1p3_eta1_blend_ft0p50",
            1.3,
            0.50,
            support_power=1.15,
            semantic_power=1.7,
        )
    )
    return configs


# Replace the base sweep's default_configs with the final-sample grid; it
# must be in place before the base script's main() is called below.
sweep.default_configs = finalsample_configs


if __name__ == "__main__":
    sweep.main()