Add files using upload-large-folder tool
Browse files- LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/nohup.log +0 -0
- LTA_openwebtext_dualt/scripts/eval_ar_8gpu_latest_temp_sweep_20260506.py +28 -0
- LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_20260508.py +51 -0
- LTA_openwebtext_dualt/scripts/eval_lm1b_c1024_fullycoupled_8gpu_1m_sde.sh +39 -0
- LTA_openwebtext_dualt/scripts/extract_mauve_features_lm1b.py +110 -0
- LTA_openwebtext_dualt/scripts/launch_lta_lm1b_fullycoupled_t5_logitnormal_8gpu.sh +6 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh +250 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_elftsampling_mask1_wd0p1_fp32_8gpu.sh +202 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_uniform_mask1_swiglu_wd0p1_fp32_8gpu.sh +204 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh +271 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_selfcond_p05_4gpu.sh +65 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_blockar_twostream_8gpu.sh +120 -0
- LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_p50_randk0_4_8gpu.sh +60 -0
- LTA_openwebtext_dualt/scripts/make_compact_gpt2_bpe_tokenizer.py +160 -0
- LTA_openwebtext_dualt/scripts/run_lta_lm1b_linear_simplex_len128_4gpu_save1k_with_watch.sh +217 -0
- LTA_openwebtext_dualt/scripts/run_lta_owt_t5elf_absrope_time4_dirichlet_len1025_C1_to_1024_8gpu_1m_mask1_sameT_prebos_save10k.sh +74 -0
- LTA_openwebtext_dualt/scripts/run_train8_ctx1024_rollin_sweep_4gpu.sh +235 -0
- LTA_openwebtext_dualt/scripts/run_train8_len256_headwd_sweep_bs512_until_exact_4gpu.sh +287 -0
- LTA_openwebtext_dualt/scripts/run_train8_noise_geometry_pilots_4gpu.sh +289 -0
- LTA_openwebtext_dualt/scripts/sweep_categorical_c1024_diffusion_finalsample_20260506.py +48 -0
LTA_openwebtext_dualt/logs/train8_len_sweep_compact_bs512_until_exact_4gpu/nohup.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LTA_openwebtext_dualt/scripts/eval_ar_8gpu_latest_temp_sweep_20260506.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import importlib.util
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Sibling evaluation script (same directory as this file) whose main() is reused.
BASE_PATH = Path(__file__).with_name("eval_lm1b_200k_methods_genppl_20260506.py")

# Load the sibling script by file path; its name is not a valid package import.
spec = importlib.util.spec_from_file_location("eval_lm1b_200k_methods_genppl_20260506", BASE_PATH)
_loader = None if spec is None else spec.loader
if _loader is None:
    raise RuntimeError(f"Could not load {BASE_PATH}")
base = importlib.util.module_from_spec(spec)
# Register the module before executing it, as in the importlib
# "importing a source file directly" recipe.
sys.modules[spec.name] = base
_loader.exec_module(base)


# Replace the base script's target list with a single entry.
# NOTE(review): the positional fields look like (label, method, checkpoint path);
# confirm against base.EvalTarget.
_latest_ar_target = base.EvalTarget(
    "ar_8gpu_scratch_latest",
    "ar",
    "runs/ar_lm1b_flmpack_bert_small_len128_gbs512_8gpu_1m_scratch_20260505/latest.pt",
)
base.TARGETS = [_latest_ar_target]


if __name__ == "__main__":
    base.main()
|
LTA_openwebtext_dualt/scripts/eval_dirichlet_latest_key3_20260508.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import importlib.util
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Sibling decode-sweep script (same directory as this file) that this wrapper reuses.
BASE = Path(__file__).with_name("eval_c1024_decode_sweep_20260507.py")

# Load the sibling script by file path and expose it as `base`.
spec = importlib.util.spec_from_file_location("eval_c1024_decode_sweep_20260507", BASE)
_loader = None if spec is None else spec.loader
if _loader is None:
    raise RuntimeError(f"cannot import {BASE}")
base = importlib.util.module_from_spec(spec)
# Register the module before executing it, as in the importlib
# "importing a source file directly" recipe.
sys.modules[spec.name] = base
_loader.exec_module(base)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def key_configs() -> list[base.DecodeConfig]:
    """Return the three decode configurations evaluated by this script.

    The configurations are identical except for ``concentration_max``
    (16, 64 and 1024); the value is also embedded in each configuration's name.
    """
    # NOTE(review): the meaning of the positional fields ("post", 1.0, 1.0,
    # "blend") is defined by base.DecodeConfig; confirm there.
    return [
        base.DecodeConfig(
            f"match_post_sem1_blend_c{concentration}_t1p3",
            "post",
            1.0,
            1.0,
            "blend",
            endpoint_temp=1.3,
            concentration_max=float(concentration),
        )
        for concentration in (16, 64, 1024)
    ]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Make the base sweep use key_configs() in place of its own default_configs.
# NOTE(review): this assumes base.main() looks default_configs up at call time;
# confirm in the base script.
base.default_configs = key_configs

# Run the sweep only when executed as a script, so that importing this module
# (e.g. from another wrapper) does not start an evaluation as a side effect.
if __name__ == "__main__":
    base.main()
|
LTA_openwebtext_dualt/scripts/eval_lm1b_c1024_fullycoupled_8gpu_1m_sde.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# SDE (Dirichlet resample) GenPPL eval for:
#   lta_lm1b_dirichlet_categorical_fullvocab_c1024_fullycoupled_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0
#
# Every setting below except ROOT can be overridden from the environment.
set -euo pipefail

ROOT="/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt"
cd "${ROOT}"
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false

# Run, checkpoint and scorer locations (relative paths resolve against ROOT).
: "${RUN_DIR:=runs/lta_lm1b_dirichlet_categorical_fullvocab_c1024_fullycoupled_flmpack_onehot_hardce_ddit_small_len128_gbs512_8gpu_1m_nw0}"
: "${CHECKPOINT:=${RUN_DIR}/latest.pt}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
: "${SCORER:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"

# Sampling and scoring sizes.
: "${STEPS:=1024}"
: "${N_SAMPLES:=128}"
: "${MAX_LEN:=128}"
: "${BATCH_SIZE:=16}"
: "${SCORE_BATCH:=8}"

# The default output directory encodes the step count and sample count.
: "${OUT_DIR:=${RUN_DIR}/eval_sde_steps${STEPS}_n${N_SAMPLES}}"

# Default to GPU 0 unless the caller already selected devices.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"

echo "[eval-sde] checkpoint=${CHECKPOINT}"
echo "[eval-sde] steps=${STEPS} n_samples=${N_SAMPLES} max_len=${MAX_LEN}"
echo "[eval-sde] out_dir=${OUT_DIR}"

eval_args=(
  --checkpoint "${CHECKPOINT}"
  --tokenizer_path "${TOKENIZER_PATH}"
  --scorer "${SCORER}"
  --out_dir "${OUT_DIR}"
  --steps "${STEPS}"
  --n_samples "${N_SAMPLES}"
  --max_len "${MAX_LEN}"
  --batch_size "${BATCH_SIZE}"
  --score_batch "${SCORE_BATCH}"
)

python scripts/eval_lm1b_c1024_fullycoupled_sde_genppl.py "${eval_args[@]}"
|
LTA_openwebtext_dualt/scripts/extract_mauve_features_lm1b.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoModel, AutoTokenizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
IN_PATH = Path("runs/decode_lab/mauve_export_lm1b_latest_1024.jsonl")
|
| 13 |
+
OUT_PATH = Path("runs/decode_lab/mauve_features_lm1b_latest_1024_gpt2large.npz")
|
| 14 |
+
MODEL_PATH = "/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard"
|
| 15 |
+
MAX_LEN = 256
|
| 16 |
+
BATCH_SIZE = 16
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_texts(
    path: Path | None = None,
) -> tuple[list[str], list[str], dict[str, list[str]], dict[str, list[str]]]:
    """Read the MAUVE export JSONL and split it into reference and generated texts.

    Each non-blank line is a JSON object. Objects with ``type == "reference"``
    contribute their ``raw_text`` / ``stripped_text`` to the reference lists.
    Objects with ``type == "generated"`` are grouped by a key built from their
    ``setting``: ``t<endpoint_temp, 2 decimals>_<final_from>``. Objects of any
    other type are ignored.

    Args:
        path: JSONL file to read. Defaults to the module-level ``IN_PATH``.

    Returns:
        ``(refs_raw, refs_stripped, gen_raw, gen_stripped)``, where the two
        ``gen_*`` dicts map the setting key to the texts for that setting.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a reference/generated record lacks an expected field.
    """
    source = IN_PATH if path is None else Path(path)
    refs_raw: list[str] = []
    refs_stripped: list[str] = []
    gen_raw: dict[str, list[str]] = {}
    gen_stripped: dict[str, list[str]] = {}
    with source.open(encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a stray trailing newline) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue
            obj = json.loads(line)
            typ = obj.get("type")
            if typ == "reference":
                refs_raw.append(obj["raw_text"])
                refs_stripped.append(obj["stripped_text"])
            elif typ == "generated":
                setting = obj["setting"]
                key = f"t{float(setting['endpoint_temp']):.2f}_{setting['final_from']}"
                gen_raw.setdefault(key, []).append(obj["raw_text"])
                gen_stripped.setdefault(key, []).append(obj["stripped_text"])
    return refs_raw, refs_stripped, gen_raw, gen_stripped
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@torch.no_grad()
def featurize(
    texts: list[str],
    name: str,
    tokenizer: AutoTokenizer,
    model: AutoModel,
    device: torch.device,
    *,
    batch_size: int | None = None,
    max_len: int | None = None,
) -> np.ndarray:
    """Embed each text as the hidden state of its last attended token.

    Texts are tokenized in batches (padded, truncated to ``max_len``), run
    through ``model`` with ``output_hidden_states=True``, and the last entry of
    ``out.hidden_states`` is read at each row's last position whose attention
    mask is 1. Progress is printed every 10 batches and a summary at the end.

    Args:
        texts: Strings to embed.
        name: Label used only in the progress/summary prints.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``attention_mask`` that supports ``.to(device)``.
        model: Callable returning an object with ``hidden_states``.
        device: Device the encoded batch is moved to.
        batch_size: Texts per forward pass; defaults to the module-level
            ``BATCH_SIZE``.
        max_len: Truncation length; defaults to the module-level ``MAX_LEN``.

    Returns:
        A float32 array with one row per text. For empty ``texts`` the result
        has shape ``(0, hidden_size)``, with ``hidden_size`` read from
        ``model.config`` when available and 0 otherwise.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if max_len is None:
        max_len = MAX_LEN
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks: list[np.ndarray] = []
    t0 = time.time()
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len,
            return_attention_mask=True,
        ).to(device)
        mask = enc["attention_mask"]
        out = model(
            input_ids=enc["input_ids"],
            attention_mask=mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = out.hidden_states[-1]
        # Index of the last attended position per row. Taking the maximum of
        # (mask * position) is correct for right- and left-padded batches and
        # yields 0 (not -1) for a row with no attended tokens.
        positions = torch.arange(mask.size(1), device=mask.device)
        last_idx = (mask.long() * positions).max(dim=1).values
        rows = torch.arange(hidden.size(0), device=hidden.device)
        feat = hidden[rows, last_idx.to(hidden.device)]
        chunks.append(feat.float().cpu().numpy())
        if (start // batch_size) % 10 == 0:
            print(f"{name} {start + len(batch)}/{len(texts)}", flush=True)

    if chunks:
        arr = np.concatenate(chunks, axis=0)
    else:
        # np.concatenate rejects an empty list; return an empty feature matrix.
        hidden_size = int(getattr(getattr(model, "config", None), "hidden_size", 0) or 0)
        arr = np.zeros((0, hidden_size), dtype=np.float32)
    print(f"{name} {arr.shape} time={time.time() - t0:.1f}s", flush=True)
    return arr
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main() -> None:
    """Featurize every reference and generated text set and save them to OUT_PATH.

    The output ``.npz`` holds ``ref_raw`` and ``ref_stripped``, plus
    ``gen_<key>_raw`` and ``gen_<key>_stripped`` for each generated setting key
    (processed in sorted key order).
    """
    refs_raw, refs_stripped, gen_raw, gen_stripped = load_texts()
    gen_counts = {key: len(value) for key, value in gen_raw.items()}
    print("loaded", len(refs_raw), gen_counts, flush=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Fall back to the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModel.from_pretrained(MODEL_PATH, pad_token_id=tokenizer.eos_token_id)
    model = model.to(device).eval()

    def embed(texts: list[str], label: str) -> np.ndarray:
        return featurize(texts, label, tokenizer, model, device)

    arrays: dict[str, np.ndarray] = {}
    for label, texts in (("ref_raw", refs_raw), ("ref_stripped", refs_stripped)):
        arrays[label] = embed(texts, label)
    for key in sorted(gen_raw):
        for suffix, grouped in (("raw", gen_raw), ("stripped", gen_stripped)):
            label = f"gen_{key}_{suffix}"
            arrays[label] = embed(grouped[key], label)

    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(OUT_PATH, **arrays)
    print(f"DONE {OUT_PATH}", flush=True)


if __name__ == "__main__":
    main()
|
LTA_openwebtext_dualt/scripts/launch_lta_lm1b_fullycoupled_t5_logitnormal_8gpu.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Thin alias: forwards all arguments to the "elfaligned" launch script.
set -euo pipefail

REPO_ROOT=/e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
cd "${REPO_ROOT}"

# exec replaces this shell, so the launcher's exit status becomes ours.
exec bash scripts/launch_lta_lm1b_elfaligned_t5_logitnormal_8gpu.sh "$@"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu.sh
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
| 8 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 9 |
+
export TOKENIZERS_PARALLELISM=false
|
| 10 |
+
export PYTHONUNBUFFERED=1
|
| 11 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 12 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 13 |
+
|
| 14 |
+
# Compact GPT2-BPE v2048, rebuilt as tokenizer-stream chunks:
|
| 15 |
+
# every example is length 1024 except the final tail record in each part.
|
| 16 |
+
COMPACT_VARIANT_LABEL="${COMPACT_VARIANT_LABEL:-compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32}"
|
| 17 |
+
VOCAB_SIZE="${VOCAB_SIZE:-2048}"
|
| 18 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-compact-gpt2bpe-v2048-stream1024-train-minus-100k}"
|
| 19 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/lta_tokenizers/owt_compact_gpt2bpe_v2048/tokenizer.json}"
|
| 20 |
+
|
| 21 |
+
NNODES="${NNODES:-1}"
|
| 22 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 23 |
+
NODE_RANK="${NODE_RANK:-0}"
|
| 24 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 25 |
+
MASTER_PORT="${MASTER_PORT:-32249}"
|
| 26 |
+
|
| 27 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 28 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 29 |
+
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 30 |
+
WARMUP_STEPS="${WARMUP_STEPS:-2000}"
|
| 31 |
+
NUM_WORKERS="${NUM_WORKERS:-8}"
|
| 32 |
+
DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 33 |
+
LOG_EVERY="${LOG_EVERY:-50}"
|
| 34 |
+
LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 35 |
+
EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 36 |
+
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 37 |
+
ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 38 |
+
DRY_RUN="${DRY_RUN:-0}"
|
| 39 |
+
RESUME_PATH="${RESUME_PATH:-}"
|
| 40 |
+
|
| 41 |
+
# Fully-coupled fp32 baseline recipe: AdamW, nanoGPT decay groups, RMSNorm,
|
| 42 |
+
# no output bias, hard CE, Dirichlet bridge, full-vocab wrong token.
|
| 43 |
+
LR="${LR:-6e-4}"
|
| 44 |
+
MIN_LR="${MIN_LR:-6e-5}"
|
| 45 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
|
| 46 |
+
OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
|
| 47 |
+
ADAM_BETA1="${ADAM_BETA1:-0.9}"
|
| 48 |
+
ADAM_BETA2="${ADAM_BETA2:-0.95}"
|
| 49 |
+
ADAM_EPS="${ADAM_EPS:-1e-8}"
|
| 50 |
+
GRAD_CLIP="${GRAD_CLIP:-1.0}"
|
| 51 |
+
EMA_DECAY="${EMA_DECAY:-0.0}"
|
| 52 |
+
EMA_START_STEP="${EMA_START_STEP:-0}"
|
| 53 |
+
LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
|
| 54 |
+
LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
|
| 55 |
+
LOSS_T_DROP_BELOW="${LOSS_T_DROP_BELOW:-0.2}"
|
| 56 |
+
|
| 57 |
+
T_SAMPLING_MODE="${T_SAMPLING_MODE:-logit_normal}"
|
| 58 |
+
T_SAMPLING_POWER="${T_SAMPLING_POWER:-1.0}"
|
| 59 |
+
T_SAMPLING_EPS="${T_SAMPLING_EPS:-1e-4}"
|
| 60 |
+
T_SAMPLING_LOGIT_MEAN="${T_SAMPLING_LOGIT_MEAN:--1.5}"
|
| 61 |
+
T_SAMPLING_LOGIT_STD="${T_SAMPLING_LOGIT_STD:-0.8}"
|
| 62 |
+
MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
|
| 63 |
+
MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 64 |
+
|
| 65 |
+
# Make a numeric setting safe for use in a run name:
# every "-" becomes "m" and every "." becomes "p" (e.g. -1.5 -> m1p5).
sanitize_label() {
  local label="${1//-/m}"
  printf "%s" "${label//./p}"
}
|
| 68 |
+
|
| 69 |
+
T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN}")"
|
| 70 |
+
T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD}")"
|
| 71 |
+
MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO}")"
|
| 72 |
+
MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO}")"
|
| 73 |
+
|
| 74 |
+
RUN_NAME="${RUN_NAME:-lta_owt_compact_gpt2bpe_v2048_stream1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_logitnormal_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
|
| 75 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 76 |
+
LOG_DIR="${LOG_DIR:-logs/compact_gpt2bpe_v2048_stream1024_fullycoupled_mask1_wd0p1_fp32_8gpu}"
|
| 77 |
+
LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
|
| 78 |
+
|
| 79 |
+
if [[ -n "${RESUME_PATH}" ]]; then
|
| 80 |
+
ALLOW_EXISTING_SAVE_DIR=1
|
| 81 |
+
if [[ ! -f "${RESUME_PATH}" ]]; then
|
| 82 |
+
echo "Missing resume checkpoint: ${RESUME_PATH}" >&2
|
| 83 |
+
exit 2
|
| 84 |
+
fi
|
| 85 |
+
fi
|
| 86 |
+
|
| 87 |
+
if [[ ! -f "${TOKENIZER_PATH}" ]]; then
|
| 88 |
+
echo "Missing tokenizer: ${TOKENIZER_PATH}" >&2
|
| 89 |
+
exit 2
|
| 90 |
+
fi
|
| 91 |
+
if [[ ! -d "${DATA_PATH}" ]]; then
|
| 92 |
+
echo "Missing tokenized dataset: ${DATA_PATH}" >&2
|
| 93 |
+
echo "Build it with: bash scripts/build_lta_owt_compact_gpt2bpe_stream1024_train_minus_100k_np8.sh" >&2
|
| 94 |
+
exit 2
|
| 95 |
+
fi
|
| 96 |
+
|
| 97 |
+
# Count the training examples under DATA_PATH.
# The count is read from elf_multi_part_meta.json or elf_build_meta.json when
# either exists; otherwise the dataset (or each directory under parts/) is
# loaded and measured.
# DATA_PATH is passed through the environment and the heredoc delimiter is
# quoted, so the path is never spliced into the Python source (a quote or
# backslash in the path can no longer break or alter the script).
NUM_EXAMPLES=$(DATA_PATH="${DATA_PATH}" python - <<'PY'
import json
import os
from pathlib import Path

from datasets import Sequence, load_from_disk
from datasets.features import features as hf_features

# NOTE(review): presumably a compatibility shim so that metadata naming a
# "List" feature type can be read by this `datasets` version; confirm.
hf_features._FEATURE_TYPES.setdefault("List", Sequence)

root = Path(os.environ["DATA_PATH"])
for name in ("elf_multi_part_meta.json", "elf_build_meta.json"):
    path = root / name
    if path.exists():
        print(int(json.loads(path.read_text()).get("num_examples", 0)))
        raise SystemExit
parts_root = root / "parts"
if parts_root.is_dir():
    print(sum(len(load_from_disk(str(p))) for p in sorted(parts_root.iterdir()) if p.is_dir()))
else:
    print(len(load_from_disk(str(root))))
PY
)
# The value feeds shell arithmetic below; fail with a clear message if anything
# other than a plain non-negative integer was captured from stdout.
if [[ ! "${NUM_EXAMPLES}" =~ ^[0-9]+$ ]]; then
  echo "Could not determine example count for ${DATA_PATH}; got: ${NUM_EXAMPLES}" >&2
  exit 2
fi
|
| 116 |
+
WORLD_SIZE=$(( NNODES * NPROC_PER_NODE ))
|
| 117 |
+
GRAD_ACCUM=$(( (GLOBAL_BATCH_SIZE + PER_GPU_BATCH_SIZE * WORLD_SIZE - 1) / (PER_GPU_BATCH_SIZE * WORLD_SIZE) ))
|
| 118 |
+
SAMPLES_PER_RANK=$(( (NUM_EXAMPLES + WORLD_SIZE - 1) / WORLD_SIZE ))
|
| 119 |
+
LOADER_BATCHES_PER_RANK=$(( SAMPLES_PER_RANK / PER_GPU_BATCH_SIZE ))
|
| 120 |
+
STEPS_PER_EPOCH=$(( (LOADER_BATCHES_PER_RANK + GRAD_ACCUM - 1) / GRAD_ACCUM ))
|
| 121 |
+
SAVE_EVERY="${SAVE_EVERY:-${STEPS_PER_EPOCH}}"
|
| 122 |
+
|
| 123 |
+
if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
|
| 124 |
+
echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
|
| 125 |
+
echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
|
| 126 |
+
exit 2
|
| 127 |
+
fi
|
| 128 |
+
|
| 129 |
+
mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
|
| 130 |
+
|
| 131 |
+
# TF32 is on unless ALLOW_TF32 is explicitly 0 / false / False.
case "${ALLOW_TF32}" in
  0|false|False)
    TF32_FLAG="--no-allow_tf32"
    TF32_LABEL="false"
    ;;
  *)
    TF32_FLAG="--allow_tf32"
    TF32_LABEL="true"
    ;;
esac
|
| 137 |
+
|
| 138 |
+
echo "[launch] method=${COMPACT_VARIANT_LABEL} host=$(hostname) time=$(date -Iseconds)"
|
| 139 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 140 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 141 |
+
echo "[launch] log_file=${LOG_FILE}"
|
| 142 |
+
echo "[launch] data_path=${DATA_PATH}"
|
| 143 |
+
echo "[launch] tokenizer=${TOKENIZER_PATH}"
|
| 144 |
+
echo "[launch] examples=${NUM_EXAMPLES} total_steps=${TOTAL_STEPS} world_size=${WORLD_SIZE} grad_accum=${GRAD_ACCUM} steps_per_epoch=${STEPS_PER_EPOCH} save_every=${SAVE_EVERY}"
|
| 145 |
+
echo "[launch] optimizer=adamw lr=${LR} min_lr=${MIN_LR} wd=${WEIGHT_DECAY} output_wd=${OUTPUT_WEIGHT_DECAY} param_groups=nanogpt ema=${EMA_DECAY}"
|
| 146 |
+
echo "[launch] fp32=true bf16=false tf32=${TF32_LABEL} norm_type=rmsnorm output_bias=false batch=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE}"
|
| 147 |
+
echo "[launch] target_loss=hard_ce t_sampling_mode=${T_SAMPLING_MODE} t_sampling_logit_mean=${T_SAMPLING_LOGIT_MEAN} t_sampling_logit_std=${T_SAMPLING_LOGIT_STD} mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
|
| 148 |
+
if [[ -n "${RESUME_PATH}" ]]; then
|
| 149 |
+
echo "[launch] resume_path=${RESUME_PATH}"
|
| 150 |
+
fi
|
| 151 |
+
|
| 152 |
+
if [[ "${DRY_RUN}" == "1" || "${DRY_RUN}" == "true" || "${DRY_RUN}" == "True" ]]; then
|
| 153 |
+
echo "[launch] DRY_RUN=1, validated tokenizer/data/schedule; skipping torchrun."
|
| 154 |
+
exit 0
|
| 155 |
+
fi
|
| 156 |
+
|
| 157 |
+
# Optional resume flag; stays empty for a fresh run.
RESUME_ARGS=()
if [[ -n "${RESUME_PATH}" ]]; then
  RESUME_ARGS+=(--resume_path "${RESUME_PATH}")
fi

# LOG_FILE can be overridden to live outside LOG_DIR; make sure its directory
# exists so `tee -a` does not fail before training output is captured.
mkdir -p "$(dirname "${LOG_FILE}")"

# Launch distributed training, mirroring stdout and stderr into LOG_FILE.
# With `set -o pipefail` a training failure is still reported through the pipe.
# RESUME_ARGS is expanded with the ${arr[@]+...} guard because bash < 4.4
# treats "${arr[@]}" on an empty array as an unbound variable under `set -u`.
python -m torch.distributed.run \
  --nnodes="${NNODES}" \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank="${NODE_RANK}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py \
  ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} \
  --data_path "${DATA_PATH}" \
  --tokenized_hf \
  --tokenized_pad_token pad \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --save_dir "${SAVE_DIR}" \
  --max_len 1024 \
  --batch_size "${PER_GPU_BATCH_SIZE}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
  --epochs 0 \
  --total_steps "${TOTAL_STEPS}" \
  --warmup_steps "${WARMUP_STEPS}" \
  --log_every "${LOG_EVERY}" \
  --eval_every "${EVAL_EVERY}" \
  --save_every "${SAVE_EVERY}" \
  --latest_every "${LATEST_EVERY}" \
  --optimizer adamw \
  --lr "${LR}" \
  --lr_schedule cosine \
  --min_lr "${MIN_LR}" \
  --weight_decay "${WEIGHT_DECAY}" \
  --output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
  --adamw_param_groups nanogpt \
  --adam_beta1 "${ADAM_BETA1}" \
  --adam_beta2 "${ADAM_BETA2}" \
  --adam_eps "${ADAM_EPS}" \
  --ema_decay "${EMA_DECAY}" \
  --ema_start_step "${EMA_START_STEP}" \
  --grad_clip "${GRAD_CLIP}" \
  --seed 123 \
  --d_model 768 \
  --cond_dim 128 \
  --n_layers 12 \
  --n_heads 12 \
  --dim_ff 3072 \
  --dropout 0.0 \
  --no-output_bias \
  --norm_type rmsnorm \
  --model_type ddit \
  --state_format prob \
  --bridge dirichlet \
  --target_loss hard_ce \
  --loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
  --loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
  --loss_t_drop_below "${LOSS_T_DROP_BELOW}" \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --t_sampling_mode "${T_SAMPLING_MODE}" \
  --t_sampling_power "${T_SAMPLING_POWER}" \
  --t_sampling_logit_mean "${T_SAMPLING_LOGIT_MEAN}" \
  --t_sampling_logit_std "${T_SAMPLING_LOGIT_STD}" \
  --t_sampling_eps "${T_SAMPLING_EPS}" \
  --dual_t \
  --corrupt_t_mode same \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio "${MIN_MASK_RATIO}" \
  --max_mask_ratio "${MAX_MASK_RATIO}" \
  --wrong_token_replace_prob 1.0 \
  --wrong_token_schedule linear_t \
  --wrong_token_exp_k 1.0 \
  --dirichlet_concentration_min 1.0 \
  --dirichlet_concentration_max 1024 \
  --dirichlet_endpoint_mode categorical_dual_t \
  --dirichlet_semantic_t_mode same \
  --dirichlet_semantic_t_value 0.0 \
  --categorical_wrong_from_full_vocab \
  --simplex_bridge_sampler dirichlet \
  --eps 1e-8 \
  --infer_steps 1024 \
  --decode_damping 1.0 \
  --max_gamma 1.0 \
  --decode_solver flowmap \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma -1 \
  "${TF32_FLAG}" \
  --ddp_gradient_as_bucket_view \
  2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_elftsampling_mask1_wd0p1_fp32_8gpu.sh
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
| 8 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 9 |
+
export TOKENIZERS_PARALLELISM=false
|
| 10 |
+
export PYTHONUNBUFFERED=1
|
| 11 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 12 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 13 |
+
|
| 14 |
+
# Fully-coupled OWT baseline with FP32 params/activations and nanoGPT-style AdamW decay:
|
| 15 |
+
# decay: matrix / embedding params (p.dim() >= 2)
|
| 16 |
+
# no decay: bias / norm / 1D params
|
| 17 |
+
# This intentionally does NOT pass --bf16. TF32 is enabled by default so H200 uses
|
| 18 |
+
# Tensor Cores; set ALLOW_TF32=0 for strict FP32 debugging.
|
| 19 |
+
|
| 20 |
+
T_SAMPLING_MODE="${T_SAMPLING_MODE:-logit_normal}"
|
| 21 |
+
T_SAMPLING_POWER="${T_SAMPLING_POWER:-1.0}"
|
| 22 |
+
T_SAMPLING_EPS="${T_SAMPLING_EPS:-1e-4}"
|
| 23 |
+
T_SAMPLING_LOGIT_MEAN="${T_SAMPLING_LOGIT_MEAN:--1.5}"
|
| 24 |
+
T_SAMPLING_LOGIT_STD="${T_SAMPLING_LOGIT_STD:-0.8}"
|
| 25 |
+
MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
|
| 26 |
+
MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 27 |
+
|
| 28 |
+
sanitize_label() {
|
| 29 |
+
printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN}")"
|
| 33 |
+
T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD}")"
|
| 34 |
+
MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO}")"
|
| 35 |
+
MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO}")"
|
| 36 |
+
|
| 37 |
+
RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_elftsampling_${T_LOGIT_MEAN_LABEL}_s${T_LOGIT_STD_LABEL}_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_nanogpt_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
|
| 38 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 39 |
+
LOG_DIR="${LOG_DIR:-logs/fullycoupled_elftsampling_mask1_wd0p1_fp32_8gpu}"
|
| 40 |
+
LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
|
| 41 |
+
|
| 42 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
|
| 43 |
+
OWT_CACHE="${OWT_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
|
| 44 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
|
| 45 |
+
|
| 46 |
+
NNODES="${NNODES:-1}"
|
| 47 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 48 |
+
NODE_RANK="${NODE_RANK:-0}"
|
| 49 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 50 |
+
MASTER_PORT="${MASTER_PORT:-31997}"
|
| 51 |
+
|
| 52 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 53 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 54 |
+
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 55 |
+
WARMUP_STEPS="${WARMUP_STEPS:-2000}"
|
| 56 |
+
NUM_WORKERS="${NUM_WORKERS:-8}"
|
| 57 |
+
DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 58 |
+
LOG_EVERY="${LOG_EVERY:-50}"
|
| 59 |
+
SAVE_EVERY="${SAVE_EVERY:-50000}"
|
| 60 |
+
LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 61 |
+
EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 62 |
+
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 63 |
+
ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 64 |
+
DRY_RUN="${DRY_RUN:-0}"
|
| 65 |
+
|
| 66 |
+
LR="${LR:-6e-4}"
|
| 67 |
+
MIN_LR="${MIN_LR:-6e-5}"
|
| 68 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
|
| 69 |
+
OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
|
| 70 |
+
ADAM_BETA1="${ADAM_BETA1:-0.9}"
|
| 71 |
+
ADAM_BETA2="${ADAM_BETA2:-0.95}"
|
| 72 |
+
ADAM_EPS="${ADAM_EPS:-1e-8}"
|
| 73 |
+
GRAD_CLIP="${GRAD_CLIP:-1.0}"
|
| 74 |
+
EMA_DECAY="${EMA_DECAY:-0.0}"
|
| 75 |
+
EMA_START_STEP="${EMA_START_STEP:-0}"
|
| 76 |
+
LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
|
| 77 |
+
LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
|
| 78 |
+
LOSS_T_DROP_BELOW="${LOSS_T_DROP_BELOW:-0.2}"
|
| 79 |
+
|
| 80 |
+
if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
|
| 81 |
+
echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
|
| 82 |
+
echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
|
| 83 |
+
exit 2
|
| 84 |
+
fi
|
| 85 |
+
|
| 86 |
+
mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
|
| 87 |
+
|
| 88 |
+
TF32_FLAG="--allow_tf32"
|
| 89 |
+
TF32_LABEL="true"
|
| 90 |
+
if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
|
| 91 |
+
TF32_FLAG="--no-allow_tf32"
|
| 92 |
+
TF32_LABEL="false"
|
| 93 |
+
fi
|
| 94 |
+
|
| 95 |
+
echo "[launch] method=owt_fullycoupled_adamw_wd0p1_nanogpt_fp32 host=$(hostname) time=$(date -Iseconds)"
|
| 96 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 97 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 98 |
+
echo "[launch] log_file=${LOG_FILE}"
|
| 99 |
+
echo "[launch] data_path=${DATA_PATH}"
|
| 100 |
+
echo "[launch] owt_cache=${OWT_CACHE}"
|
| 101 |
+
echo "[launch] optimizer=adamw lr=${LR} min_lr=${MIN_LR} wd=${WEIGHT_DECAY} output_wd=${OUTPUT_WEIGHT_DECAY} param_groups=nanogpt ema=${EMA_DECAY}"
|
| 102 |
+
echo "[launch] fp32=true bf16=false tf32=${TF32_LABEL} norm_type=rmsnorm output_bias=false batch=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE}"
|
| 103 |
+
echo "[launch] loss_t_weight_mode=${LOSS_T_WEIGHT_MODE} loss_t_min_weight=${LOSS_T_MIN_WEIGHT} loss_t_drop_below=${LOSS_T_DROP_BELOW}"
|
| 104 |
+
echo "[launch] target_loss=hard_ce t_sampling_mode=${T_SAMPLING_MODE} t_sampling_logit_mean=${T_SAMPLING_LOGIT_MEAN} t_sampling_logit_std=${T_SAMPLING_LOGIT_STD} t_sampling_power=${T_SAMPLING_POWER} t_sampling_eps=${T_SAMPLING_EPS} mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
|
| 105 |
+
|
| 106 |
+
if [[ "${DRY_RUN}" == "1" || "${DRY_RUN}" == "true" || "${DRY_RUN}" == "True" ]]; then
|
| 107 |
+
echo "[launch] DRY_RUN=1, validated launch parameters; skipping torchrun."
|
| 108 |
+
exit 0
|
| 109 |
+
fi
|
| 110 |
+
|
| 111 |
+
python -m torch.distributed.run \
|
| 112 |
+
--nnodes="${NNODES}" \
|
| 113 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 114 |
+
--node_rank="${NODE_RANK}" \
|
| 115 |
+
--master_addr="${MASTER_ADDR}" \
|
| 116 |
+
--master_port="${MASTER_PORT}" \
|
| 117 |
+
train.py \
|
| 118 |
+
--data_path "${DATA_PATH}" \
|
| 119 |
+
--openwebtext_split train_minus_100k \
|
| 120 |
+
--text_column text \
|
| 121 |
+
--detokenizer auto \
|
| 122 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 123 |
+
--save_dir "${SAVE_DIR}" \
|
| 124 |
+
--wrap \
|
| 125 |
+
--wrap_mode stream \
|
| 126 |
+
--owt_cached_chunks \
|
| 127 |
+
--owt_chunk_cache_dir "${OWT_CACHE}" \
|
| 128 |
+
--max_len 1024 \
|
| 129 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 130 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 131 |
+
--num_workers "${NUM_WORKERS}" \
|
| 132 |
+
--dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
|
| 133 |
+
--total_steps "${TOTAL_STEPS}" \
|
| 134 |
+
--warmup_steps "${WARMUP_STEPS}" \
|
| 135 |
+
--log_every "${LOG_EVERY}" \
|
| 136 |
+
--eval_every "${EVAL_EVERY}" \
|
| 137 |
+
--save_every "${SAVE_EVERY}" \
|
| 138 |
+
--latest_every "${LATEST_EVERY}" \
|
| 139 |
+
--optimizer adamw \
|
| 140 |
+
--lr "${LR}" \
|
| 141 |
+
--lr_schedule cosine \
|
| 142 |
+
--min_lr "${MIN_LR}" \
|
| 143 |
+
--weight_decay "${WEIGHT_DECAY}" \
|
| 144 |
+
--output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
|
| 145 |
+
--adamw_param_groups nanogpt \
|
| 146 |
+
--adam_beta1 "${ADAM_BETA1}" \
|
| 147 |
+
--adam_beta2 "${ADAM_BETA2}" \
|
| 148 |
+
--adam_eps "${ADAM_EPS}" \
|
| 149 |
+
--ema_decay "${EMA_DECAY}" \
|
| 150 |
+
--ema_start_step "${EMA_START_STEP}" \
|
| 151 |
+
--grad_clip "${GRAD_CLIP}" \
|
| 152 |
+
--seed 123 \
|
| 153 |
+
--d_model 768 \
|
| 154 |
+
--cond_dim 128 \
|
| 155 |
+
--n_layers 12 \
|
| 156 |
+
--n_heads 12 \
|
| 157 |
+
--dim_ff 3072 \
|
| 158 |
+
--dropout 0.0 \
|
| 159 |
+
--no-output_bias \
|
| 160 |
+
--norm_type rmsnorm \
|
| 161 |
+
--model_type ddit \
|
| 162 |
+
--state_format prob \
|
| 163 |
+
--bridge dirichlet \
|
| 164 |
+
--target_loss hard_ce \
|
| 165 |
+
--loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
|
| 166 |
+
--loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
|
| 167 |
+
--loss_t_drop_below "${LOSS_T_DROP_BELOW}" \
|
| 168 |
+
--target_prob 1.0 \
|
| 169 |
+
--min_t 0.0 \
|
| 170 |
+
--max_t 1.0 \
|
| 171 |
+
--t_sampling_mode "${T_SAMPLING_MODE}" \
|
| 172 |
+
--t_sampling_power "${T_SAMPLING_POWER}" \
|
| 173 |
+
--t_sampling_logit_mean "${T_SAMPLING_LOGIT_MEAN}" \
|
| 174 |
+
--t_sampling_logit_std "${T_SAMPLING_LOGIT_STD}" \
|
| 175 |
+
--t_sampling_eps "${T_SAMPLING_EPS}" \
|
| 176 |
+
--dual_t \
|
| 177 |
+
--corrupt_t_mode same \
|
| 178 |
+
--corrupt_min_t 0.0 \
|
| 179 |
+
--corrupt_max_t 1.0 \
|
| 180 |
+
--min_mask_ratio "${MIN_MASK_RATIO}" \
|
| 181 |
+
--max_mask_ratio "${MAX_MASK_RATIO}" \
|
| 182 |
+
--wrong_token_replace_prob 1.0 \
|
| 183 |
+
--wrong_token_schedule linear_t \
|
| 184 |
+
--wrong_token_exp_k 1.0 \
|
| 185 |
+
--dirichlet_concentration_min 1.0 \
|
| 186 |
+
--dirichlet_concentration_max 1024 \
|
| 187 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 188 |
+
--dirichlet_semantic_t_mode same \
|
| 189 |
+
--dirichlet_semantic_t_value 0.0 \
|
| 190 |
+
--categorical_wrong_from_full_vocab \
|
| 191 |
+
--simplex_bridge_sampler dirichlet \
|
| 192 |
+
--eps 1e-8 \
|
| 193 |
+
--infer_steps 1024 \
|
| 194 |
+
--decode_damping 1.0 \
|
| 195 |
+
--max_gamma 1.0 \
|
| 196 |
+
--decode_solver flowmap \
|
| 197 |
+
--noise_init logistic_normal \
|
| 198 |
+
--bridge_noise_init logistic_normal \
|
| 199 |
+
--noise_sigma -1 \
|
| 200 |
+
"${TF32_FLAG}" \
|
| 201 |
+
--ddp_gradient_as_bucket_view \
|
| 202 |
+
2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_fullycoupled_uniform_mask1_swiglu_wd0p1_fp32_8gpu.sh
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
| 8 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 9 |
+
export TOKENIZERS_PARALLELISM=false
|
| 10 |
+
export PYTHONUNBUFFERED=1
|
| 11 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 12 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 13 |
+
|
| 14 |
+
# Fully-coupled OWT baseline with FP32 params/activations and nanoGPT-style AdamW decay:
|
| 15 |
+
# decay: matrix / embedding params (p.dim() >= 2)
|
| 16 |
+
# no decay: bias / norm / 1D params
|
| 17 |
+
# This intentionally does NOT pass --bf16. TF32 is enabled by default so H200 uses
|
| 18 |
+
# Tensor Cores; set ALLOW_TF32=0 for strict FP32 debugging.
|
| 19 |
+
|
| 20 |
+
T_SAMPLING_MODE="${T_SAMPLING_MODE:-uniform}"
|
| 21 |
+
T_SAMPLING_POWER="${T_SAMPLING_POWER:-1.0}"
|
| 22 |
+
T_SAMPLING_EPS="${T_SAMPLING_EPS:-1e-4}"
|
| 23 |
+
T_SAMPLING_LOGIT_MEAN="${T_SAMPLING_LOGIT_MEAN:--0.22}"
|
| 24 |
+
T_SAMPLING_LOGIT_STD="${T_SAMPLING_LOGIT_STD:-0.5}"
|
| 25 |
+
MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
|
| 26 |
+
MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 27 |
+
DDIT_MLP_TYPE="${DDIT_MLP_TYPE:-swiglu}"
|
| 28 |
+
|
| 29 |
+
sanitize_label() {
|
| 30 |
+
printf "%s" "$1" | sed -e 's/-/m/g' -e 's/\./p/g'
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
T_LOGIT_MEAN_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_MEAN}")"
|
| 34 |
+
T_LOGIT_STD_LABEL="$(sanitize_label "${T_SAMPLING_LOGIT_STD}")"
|
| 35 |
+
MIN_MASK_RATIO_LABEL="$(sanitize_label "${MIN_MASK_RATIO}")"
|
| 36 |
+
MAX_MASK_RATIO_LABEL="$(sanitize_label "${MAX_MASK_RATIO}")"
|
| 37 |
+
|
| 38 |
+
RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_${DDIT_MLP_TYPE}_adamw_wd0p1_uniformt_hardce_mask${MIN_MASK_RATIO_LABEL}-${MAX_MASK_RATIO_LABEL}_nanogpt_fp32_ddit768x12_gbs512_8gpu_1m_$(date +%Y%m%d_%H%M%S)}"
|
| 39 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 40 |
+
LOG_DIR="${LOG_DIR:-logs/fullycoupled_uniform_mask1_swiglu_wd0p1_fp32_8gpu}"
|
| 41 |
+
LOG_FILE="${LOG_FILE:-${LOG_DIR}/${RUN_NAME}.log}"
|
| 42 |
+
|
| 43 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
|
| 44 |
+
OWT_CACHE="${OWT_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
|
| 45 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
|
| 46 |
+
|
| 47 |
+
NNODES="${NNODES:-1}"
|
| 48 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 49 |
+
NODE_RANK="${NODE_RANK:-0}"
|
| 50 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 51 |
+
MASTER_PORT="${MASTER_PORT:-31997}"
|
| 52 |
+
|
| 53 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 54 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 55 |
+
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 56 |
+
WARMUP_STEPS="${WARMUP_STEPS:-2000}"
|
| 57 |
+
NUM_WORKERS="${NUM_WORKERS:-8}"
|
| 58 |
+
DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 59 |
+
LOG_EVERY="${LOG_EVERY:-50}"
|
| 60 |
+
SAVE_EVERY="${SAVE_EVERY:-50000}"
|
| 61 |
+
LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 62 |
+
EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 63 |
+
ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 64 |
+
ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 65 |
+
DRY_RUN="${DRY_RUN:-0}"
|
| 66 |
+
|
| 67 |
+
LR="${LR:-6e-4}"
|
| 68 |
+
MIN_LR="${MIN_LR:-6e-5}"
|
| 69 |
+
WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
|
| 70 |
+
OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
|
| 71 |
+
ADAM_BETA1="${ADAM_BETA1:-0.9}"
|
| 72 |
+
ADAM_BETA2="${ADAM_BETA2:-0.95}"
|
| 73 |
+
ADAM_EPS="${ADAM_EPS:-1e-8}"
|
| 74 |
+
GRAD_CLIP="${GRAD_CLIP:-1.0}"
|
| 75 |
+
EMA_DECAY="${EMA_DECAY:-0.0}"
|
| 76 |
+
EMA_START_STEP="${EMA_START_STEP:-0}"
|
| 77 |
+
LOSS_T_WEIGHT_MODE="${LOSS_T_WEIGHT_MODE:-none}"
|
| 78 |
+
LOSS_T_MIN_WEIGHT="${LOSS_T_MIN_WEIGHT:-0.0}"
|
| 79 |
+
LOSS_T_DROP_BELOW="${LOSS_T_DROP_BELOW:-0.2}"
|
| 80 |
+
|
| 81 |
+
if [[ -f "${SAVE_DIR}/args.json" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]]; then
|
| 82 |
+
echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
|
| 83 |
+
echo "Use a new RUN_NAME/SAVE_DIR or set ALLOW_EXISTING_SAVE_DIR=1 intentionally." >&2
|
| 84 |
+
exit 2
|
| 85 |
+
fi
|
| 86 |
+
|
| 87 |
+
mkdir -p "${LOG_DIR}" "${SAVE_DIR}"
|
| 88 |
+
|
| 89 |
+
TF32_FLAG="--allow_tf32"
|
| 90 |
+
TF32_LABEL="true"
|
| 91 |
+
if [[ "${ALLOW_TF32}" == "0" || "${ALLOW_TF32}" == "false" || "${ALLOW_TF32}" == "False" ]]; then
|
| 92 |
+
TF32_FLAG="--no-allow_tf32"
|
| 93 |
+
TF32_LABEL="false"
|
| 94 |
+
fi
|
| 95 |
+
|
| 96 |
+
echo "[launch] method=owt_fullycoupled_adamw_wd0p1_nanogpt_fp32 host=$(hostname) time=$(date -Iseconds)"
|
| 97 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 98 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 99 |
+
echo "[launch] log_file=${LOG_FILE}"
|
| 100 |
+
echo "[launch] data_path=${DATA_PATH}"
|
| 101 |
+
echo "[launch] owt_cache=${OWT_CACHE}"
|
| 102 |
+
echo "[launch] optimizer=adamw lr=${LR} min_lr=${MIN_LR} wd=${WEIGHT_DECAY} output_wd=${OUTPUT_WEIGHT_DECAY} param_groups=nanogpt ema=${EMA_DECAY}"
|
| 103 |
+
echo "[launch] fp32=true bf16=false tf32=${TF32_LABEL} norm_type=rmsnorm output_bias=false ddit_mlp_type=${DDIT_MLP_TYPE} batch=${GLOBAL_BATCH_SIZE} per_gpu=${PER_GPU_BATCH_SIZE}"
|
| 104 |
+
echo "[launch] loss_t_weight_mode=${LOSS_T_WEIGHT_MODE} loss_t_min_weight=${LOSS_T_MIN_WEIGHT} loss_t_drop_below=${LOSS_T_DROP_BELOW}"
|
| 105 |
+
echo "[launch] target_loss=hard_ce t_sampling_mode=${T_SAMPLING_MODE} t_sampling_logit_mean=${T_SAMPLING_LOGIT_MEAN} t_sampling_logit_std=${T_SAMPLING_LOGIT_STD} t_sampling_power=${T_SAMPLING_POWER} t_sampling_eps=${T_SAMPLING_EPS} mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
|
| 106 |
+
|
| 107 |
+
if [[ "${DRY_RUN}" == "1" || "${DRY_RUN}" == "true" || "${DRY_RUN}" == "True" ]]; then
|
| 108 |
+
echo "[launch] DRY_RUN=1, validated launch parameters; skipping torchrun."
|
| 109 |
+
exit 0
|
| 110 |
+
fi
|
| 111 |
+
|
| 112 |
+
python -m torch.distributed.run \
|
| 113 |
+
--nnodes="${NNODES}" \
|
| 114 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 115 |
+
--node_rank="${NODE_RANK}" \
|
| 116 |
+
--master_addr="${MASTER_ADDR}" \
|
| 117 |
+
--master_port="${MASTER_PORT}" \
|
| 118 |
+
train.py \
|
| 119 |
+
--data_path "${DATA_PATH}" \
|
| 120 |
+
--openwebtext_split train_minus_100k \
|
| 121 |
+
--text_column text \
|
| 122 |
+
--detokenizer auto \
|
| 123 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 124 |
+
--save_dir "${SAVE_DIR}" \
|
| 125 |
+
--wrap \
|
| 126 |
+
--wrap_mode stream \
|
| 127 |
+
--owt_cached_chunks \
|
| 128 |
+
--owt_chunk_cache_dir "${OWT_CACHE}" \
|
| 129 |
+
--max_len 1024 \
|
| 130 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 131 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 132 |
+
--num_workers "${NUM_WORKERS}" \
|
| 133 |
+
--dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
|
| 134 |
+
--total_steps "${TOTAL_STEPS}" \
|
| 135 |
+
--warmup_steps "${WARMUP_STEPS}" \
|
| 136 |
+
--log_every "${LOG_EVERY}" \
|
| 137 |
+
--eval_every "${EVAL_EVERY}" \
|
| 138 |
+
--save_every "${SAVE_EVERY}" \
|
| 139 |
+
--latest_every "${LATEST_EVERY}" \
|
| 140 |
+
--optimizer adamw \
|
| 141 |
+
--lr "${LR}" \
|
| 142 |
+
--lr_schedule cosine \
|
| 143 |
+
--min_lr "${MIN_LR}" \
|
| 144 |
+
--weight_decay "${WEIGHT_DECAY}" \
|
| 145 |
+
--output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
|
| 146 |
+
--adamw_param_groups nanogpt \
|
| 147 |
+
--adam_beta1 "${ADAM_BETA1}" \
|
| 148 |
+
--adam_beta2 "${ADAM_BETA2}" \
|
| 149 |
+
--adam_eps "${ADAM_EPS}" \
|
| 150 |
+
--ema_decay "${EMA_DECAY}" \
|
| 151 |
+
--ema_start_step "${EMA_START_STEP}" \
|
| 152 |
+
--grad_clip "${GRAD_CLIP}" \
|
| 153 |
+
--seed 123 \
|
| 154 |
+
--d_model 768 \
|
| 155 |
+
--cond_dim 128 \
|
| 156 |
+
--n_layers 12 \
|
| 157 |
+
--n_heads 12 \
|
| 158 |
+
--dim_ff 3072 \
|
| 159 |
+
--dropout 0.0 \
|
| 160 |
+
--no-output_bias \
|
| 161 |
+
--norm_type rmsnorm \
|
| 162 |
+
--model_type ddit \
|
| 163 |
+
--ddit_mlp_type "${DDIT_MLP_TYPE}" \
|
| 164 |
+
--state_format prob \
|
| 165 |
+
--bridge dirichlet \
|
| 166 |
+
--target_loss hard_ce \
|
| 167 |
+
--loss_t_weight_mode "${LOSS_T_WEIGHT_MODE}" \
|
| 168 |
+
--loss_t_min_weight "${LOSS_T_MIN_WEIGHT}" \
|
| 169 |
+
--loss_t_drop_below "${LOSS_T_DROP_BELOW}" \
|
| 170 |
+
--target_prob 1.0 \
|
| 171 |
+
--min_t 0.0 \
|
| 172 |
+
--max_t 1.0 \
|
| 173 |
+
--t_sampling_mode "${T_SAMPLING_MODE}" \
|
| 174 |
+
--t_sampling_power "${T_SAMPLING_POWER}" \
|
| 175 |
+
--t_sampling_logit_mean "${T_SAMPLING_LOGIT_MEAN}" \
|
| 176 |
+
--t_sampling_logit_std "${T_SAMPLING_LOGIT_STD}" \
|
| 177 |
+
--t_sampling_eps "${T_SAMPLING_EPS}" \
|
| 178 |
+
--dual_t \
|
| 179 |
+
--corrupt_t_mode same \
|
| 180 |
+
--corrupt_min_t 0.0 \
|
| 181 |
+
--corrupt_max_t 1.0 \
|
| 182 |
+
--min_mask_ratio "${MIN_MASK_RATIO}" \
|
| 183 |
+
--max_mask_ratio "${MAX_MASK_RATIO}" \
|
| 184 |
+
--wrong_token_replace_prob 1.0 \
|
| 185 |
+
--wrong_token_schedule linear_t \
|
| 186 |
+
--wrong_token_exp_k 1.0 \
|
| 187 |
+
--dirichlet_concentration_min 1.0 \
|
| 188 |
+
--dirichlet_concentration_max 1024 \
|
| 189 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 190 |
+
--dirichlet_semantic_t_mode same \
|
| 191 |
+
--dirichlet_semantic_t_value 0.0 \
|
| 192 |
+
--categorical_wrong_from_full_vocab \
|
| 193 |
+
--simplex_bridge_sampler dirichlet \
|
| 194 |
+
--eps 1e-8 \
|
| 195 |
+
--infer_steps 1024 \
|
| 196 |
+
--decode_damping 1.0 \
|
| 197 |
+
--max_gamma 1.0 \
|
| 198 |
+
--decode_solver flowmap \
|
| 199 |
+
--noise_init logistic_normal \
|
| 200 |
+
--bridge_noise_init logistic_normal \
|
| 201 |
+
--noise_sigma -1 \
|
| 202 |
+
"${TF32_FLAG}" \
|
| 203 |
+
--ddp_gradient_as_bucket_view \
|
| 204 |
+
2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 7 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
|
| 8 |
+
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
| 9 |
+
export TOKENIZERS_PARALLELISM=false
|
| 10 |
+
export PYTHONUNBUFFERED=1
|
| 11 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 12 |
+
export TORCH_DISTRIBUTED_TIMEOUT="${TORCH_DISTRIBUTED_TIMEOUT:-3600}"
|
| 13 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 14 |
+
|
| 15 |
+
STAMP="$(date +%Y%m%d_%H%M%S)"
|
| 16 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext}"
|
| 17 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
|
| 18 |
+
OWT_CHUNK_CACHE_DIR="${OWT_CHUNK_CACHE_DIR:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
|
| 19 |
+
OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-0}"
|
| 20 |
+
|
| 21 |
+
NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
|
| 22 |
+
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
|
| 23 |
+
MASTER_PORT="${MASTER_PORT:-32251}"
|
| 24 |
+
|
| 25 |
+
MAX_LEN="${MAX_LEN:-1024}"
|
| 26 |
+
LOWK_CLEAN_TOKENS="${LOWK_CLEAN_TOKENS:-64,128,256}"
|
| 27 |
+
CLEAN_STATE_MODE="${CLEAN_STATE_MODE:-onehot}"
|
| 28 |
+
MASK_MIXTURE_LOWK_PROB="${MASK_MIXTURE_LOWK_PROB:-1.0}"
|
| 29 |
+
MASK_MIXTURE_ALL_PROB="${MASK_MIXTURE_ALL_PROB:-0.0}"
|
| 30 |
+
VOCAB_SIZE_OVERRIDE="${VOCAB_SIZE_OVERRIDE:-0}"
|
| 31 |
+
D_MODEL="${D_MODEL:-768}"
|
| 32 |
+
COND_DIM="${COND_DIM:-128}"
|
| 33 |
+
N_LAYERS="${N_LAYERS:-12}"
|
| 34 |
+
N_HEADS="${N_HEADS:-12}"
|
| 35 |
+
DIM_FF="${DIM_FF:-3072}"
|
| 36 |
+
TOTAL_STEPS="${TOTAL_STEPS:-500}"
|
| 37 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 38 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 39 |
+
NUM_WORKERS="${NUM_WORKERS:-4}"
|
| 40 |
+
DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 41 |
+
MUON_IMPL="${MUON_IMPL:-legacy}"
|
| 42 |
+
OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
|
| 43 |
+
MIN_MASK_RATIO="${MIN_MASK_RATIO:-0.1}"
|
| 44 |
+
MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 45 |
+
MASK_RATIO_FLOOR_SCHEDULE="${MASK_RATIO_FLOOR_SCHEDULE:-none}"
|
| 46 |
+
DIRICHLET_CONCENTRATION_MIN="${DIRICHLET_CONCENTRATION_MIN:-1.0}"
|
| 47 |
+
DIRICHLET_CONCENTRATION_MAX="${DIRICHLET_CONCENTRATION_MAX:-1024}"
|
| 48 |
+
DIRICHLET_ENDPOINT_MODE="${DIRICHLET_ENDPOINT_MODE:-categorical_dual_t}"
|
| 49 |
+
DIRICHLET_SEMANTIC_T_MODE="${DIRICHLET_SEMANTIC_T_MODE:-same}"
|
| 50 |
+
DIRICHLET_SEMANTIC_T_VALUE="${DIRICHLET_SEMANTIC_T_VALUE:-0.0}"
|
| 51 |
+
SIMPLEX_BRIDGE_SAMPLER="${SIMPLEX_BRIDGE_SAMPLER:-dirichlet}"
|
| 52 |
+
ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA="${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA:-0.0}"
|
| 53 |
+
CATEGORICAL_WRONG_BASIN_TOKEN_IDS="${CATEGORICAL_WRONG_BASIN_TOKEN_IDS:-}"
|
| 54 |
+
CATEGORICAL_WRONG_BASIN_PROB="${CATEGORICAL_WRONG_BASIN_PROB:-0.0}"
|
| 55 |
+
CATEGORICAL_WRONG_UNIGRAM_PROB="${CATEGORICAL_WRONG_UNIGRAM_PROB:-0.0}"
|
| 56 |
+
CATEGORICAL_WRONG_UNIFORM_PROB="${CATEGORICAL_WRONG_UNIFORM_PROB:-0.0}"
|
| 57 |
+
CATEGORICAL_WRONG_PROB_FLOOR="${CATEGORICAL_WRONG_PROB_FLOOR:-0.0}"
|
| 58 |
+
CATEGORICAL_WRONG_BASIN_SHARED_PROB="${CATEGORICAL_WRONG_BASIN_SHARED_PROB:-0.0}"
|
| 59 |
+
CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB="${CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB:-0.0}"
|
| 60 |
+
LOGISTIC_NORMAL_SIGMA_MIN="${LOGISTIC_NORMAL_SIGMA_MIN:-0.1}"
|
| 61 |
+
LOGISTIC_NORMAL_SIGMA_MAX="${LOGISTIC_NORMAL_SIGMA_MAX:-1.0}"
|
| 62 |
+
LOGISTIC_NORMAL_TAU_MIN="${LOGISTIC_NORMAL_TAU_MIN:-1.0}"
|
| 63 |
+
LOGISTIC_NORMAL_TAU_MAX="${LOGISTIC_NORMAL_TAU_MAX:-1.0}"
|
| 64 |
+
ROLLOUT_TRAIN_PROB="${ROLLOUT_TRAIN_PROB:-0.0}"
|
| 65 |
+
ROLLOUT_TRAIN_STEPS="${ROLLOUT_TRAIN_STEPS:-1}"
|
| 66 |
+
ROLLOUT_TRAIN_STEPS_MIN="${ROLLOUT_TRAIN_STEPS_MIN:--1}"
|
| 67 |
+
ROLLOUT_TRAIN_INFER_STEPS="${ROLLOUT_TRAIN_INFER_STEPS:-64}"
|
| 68 |
+
ROLLOUT_TRAIN_TIME_MODE="${ROLLOUT_TRAIN_TIME_MODE:-fixed_steps}"
|
| 69 |
+
ROLLOUT_TRAIN_S_DIST="${ROLLOUT_TRAIN_S_DIST:-uniform}"
|
| 70 |
+
ROLLOUT_TRAIN_S_MIN_FRAC="${ROLLOUT_TRAIN_S_MIN_FRAC:-0.0}"
|
| 71 |
+
ROLLOUT_TRAIN_S_MAX_FRAC="${ROLLOUT_TRAIN_S_MAX_FRAC:-0.125}"
|
| 72 |
+
ROLLOUT_TRAIN_S_BETA_ALPHA="${ROLLOUT_TRAIN_S_BETA_ALPHA:-2.0}"
|
| 73 |
+
ROLLOUT_TRAIN_S_BETA_BETA="${ROLLOUT_TRAIN_S_BETA_BETA:-6.0}"
|
| 74 |
+
ROLLOUT_TRAIN_TEMP="${ROLLOUT_TRAIN_TEMP:-1.0}"
|
| 75 |
+
ROLLOUT_TRAIN_MAX_GAMMA="${ROLLOUT_TRAIN_MAX_GAMMA:-1.0}"
|
| 76 |
+
ROLLOUT_TRAIN_CORRUPT_ONLY="${ROLLOUT_TRAIN_CORRUPT_ONLY:-1}"
|
| 77 |
+
ROLLOUT_TRAIN_SAMPLEWISE="${ROLLOUT_TRAIN_SAMPLEWISE:-0}"
|
| 78 |
+
ROLLOUT_TRAIN_SELECTED_ONLY="${ROLLOUT_TRAIN_SELECTED_ONLY:-0}"
|
| 79 |
+
ROLLOUT_TRAIN_COMPUTE_ALWAYS="${ROLLOUT_TRAIN_COMPUTE_ALWAYS:-0}"
|
| 80 |
+
ROLLOUT_TRAIN_SYNC_T="${ROLLOUT_TRAIN_SYNC_T:-0}"
|
| 81 |
+
|
| 82 |
+
TARGET_LOSS="${TARGET_LOSS:-linear_soft_kl}"
|
| 83 |
+
LINEAR_SOFT_TARGET_POWER="${LINEAR_SOFT_TARGET_POWER:-1.0}"
|
| 84 |
+
LINEAR_SOFT_TARGET_MIN_CONF="${LINEAR_SOFT_TARGET_MIN_CONF:-0.0}"
|
| 85 |
+
LINEAR_SOFT_TARGET_MAX_CONF="${LINEAR_SOFT_TARGET_MAX_CONF:-1.0}"
|
| 86 |
+
|
| 87 |
+
# tag VALUE
# Print VALUE rewritten for use inside a run name:
# ',' becomes '_', '-' becomes 'm' and '.' becomes 'p'.
tag() {
  local value="$1"
  value="${value//,/_}"
  value="${value//-/m}"
  value="${value//./p}"
  printf "%s" "${value}"
}
|
| 90 |
+
|
| 91 |
+
RUN_NAME="${RUN_NAME:-pilot_gpt2_softendpoint_mn_n${MAX_LEN}_m$(tag "${LOWK_CLEAN_TOKENS}")_${CLEAN_STATE_MODE}_${TOTAL_STEPS}_${STAMP}}"
|
| 92 |
+
SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 93 |
+
LOG_FILE="${LOG_FILE:-logs/softendpoint_mn_pilot_4gpu/${RUN_NAME}.log}"
|
| 94 |
+
RESUME_PATH="${RESUME_PATH:-}"
|
| 95 |
+
INIT_MODEL_PATH="${INIT_MODEL_PATH:-}"
|
| 96 |
+
mkdir -p "$(dirname "${LOG_FILE}")" "${SAVE_DIR}"
|
| 97 |
+
|
| 98 |
+
echo "[launch] gpt2 cached OWT soft-endpoint m/n pilot"
|
| 99 |
+
echo "[launch] run_name=${RUN_NAME}"
|
| 100 |
+
echo "[launch] save_dir=${SAVE_DIR}"
|
| 101 |
+
echo "[launch] n=${MAX_LEN} m=${LOWK_CLEAN_TOKENS} clean_state_mode=${CLEAN_STATE_MODE}"
|
| 102 |
+
echo "[launch] mask_mixture lowk=${MASK_MIXTURE_LOWK_PROB} all=${MASK_MIXTURE_ALL_PROB}"
|
| 103 |
+
echo "[launch] model d=${D_MODEL} layers=${N_LAYERS} heads=${N_HEADS} ff=${DIM_FF} vocab_override=${VOCAB_SIZE_OVERRIDE}"
|
| 104 |
+
echo "[launch] optimizer=muon muon_impl=${MUON_IMPL} weight_decay=${WEIGHT_DECAY:-0.1} output_weight_decay=${OUTPUT_WEIGHT_DECAY}"
|
| 105 |
+
echo "[launch] target_loss=${TARGET_LOSS} conf=${LINEAR_SOFT_TARGET_MIN_CONF}->${LINEAR_SOFT_TARGET_MAX_CONF} power=${LINEAR_SOFT_TARGET_POWER}"
|
| 106 |
+
echo "[launch] mask_ratio=${MIN_MASK_RATIO}->${MAX_MASK_RATIO}"
|
| 107 |
+
echo "[launch] mask_ratio_floor_schedule=${MASK_RATIO_FLOOR_SCHEDULE}"
|
| 108 |
+
echo "[launch] dirichlet C=${DIRICHLET_CONCENTRATION_MIN}->${DIRICHLET_CONCENTRATION_MAX} endpoint=${DIRICHLET_ENDPOINT_MODE} sampler=${SIMPLEX_BRIDGE_SAMPLER}"
|
| 109 |
+
echo "[launch] wrong_mix seq_alpha=${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA} wrong_floor=${CATEGORICAL_WRONG_PROB_FLOOR} unigram=${CATEGORICAL_WRONG_UNIGRAM_PROB} uniform=${CATEGORICAL_WRONG_UNIFORM_PROB} basin=${CATEGORICAL_WRONG_BASIN_PROB} basin_ids=${CATEGORICAL_WRONG_BASIN_TOKEN_IDS}"
|
| 110 |
+
echo "[launch] rollout_train prob=${ROLLOUT_TRAIN_PROB} mode=${ROLLOUT_TRAIN_TIME_MODE} steps=${ROLLOUT_TRAIN_STEPS} steps_min=${ROLLOUT_TRAIN_STEPS_MIN} infer_steps=${ROLLOUT_TRAIN_INFER_STEPS} s_dist=${ROLLOUT_TRAIN_S_DIST} s_frac=${ROLLOUT_TRAIN_S_MIN_FRAC}->${ROLLOUT_TRAIN_S_MAX_FRAC} temp=${ROLLOUT_TRAIN_TEMP} corrupt_only=${ROLLOUT_TRAIN_CORRUPT_ONLY} samplewise=${ROLLOUT_TRAIN_SAMPLEWISE} selected_only=${ROLLOUT_TRAIN_SELECTED_ONLY} sync_t=${ROLLOUT_TRAIN_SYNC_T}"
|
| 111 |
+
echo "[launch] cache=${OWT_CHUNK_CACHE_DIR} exact_repeat_per_chunk=${OWT_EXACT_REPEAT_PER_CHUNK}"
|
| 112 |
+
if [[ -n "${RESUME_PATH}" ]]; then
|
| 113 |
+
echo "[launch] resume_path=${RESUME_PATH}"
|
| 114 |
+
fi
|
| 115 |
+
if [[ -n "${INIT_MODEL_PATH}" ]]; then
|
| 116 |
+
echo "[launch] init_model_path=${INIT_MODEL_PATH}"
|
| 117 |
+
fi
|
| 118 |
+
|
| 119 |
+
# Optional checkpoint arguments: each array stays empty unless its path is set.
resume_args=()
[[ -z "${RESUME_PATH}" ]] || resume_args=(--resume_path "${RESUME_PATH}")
init_args=()
[[ -z "${INIT_MODEL_PATH}" ]] || init_args=(--init_model_path "${INIT_MODEL_PATH}")

# Translate the rollout switches into command-line flags.
rollout_bool_args=()

# corrupt_only is always passed explicitly: 0/false/False selects the
# negated flag, any other value selects the positive one.
case "${ROLLOUT_TRAIN_CORRUPT_ONLY}" in
  0|false|False) rollout_bool_args+=(--no-rollout_train_corrupt_only) ;;
  *) rollout_bool_args+=(--rollout_train_corrupt_only) ;;
esac

# The remaining switches only add their flag for 1/true/True.
case "${ROLLOUT_TRAIN_SAMPLEWISE}" in
  1|true|True) rollout_bool_args+=(--rollout_train_samplewise) ;;
esac
case "${ROLLOUT_TRAIN_SELECTED_ONLY}" in
  1|true|True) rollout_bool_args+=(--rollout_train_selected_only) ;;
esac
case "${ROLLOUT_TRAIN_COMPUTE_ALWAYS}" in
  1|true|True) rollout_bool_args+=(--rollout_train_compute_always) ;;
esac
case "${ROLLOUT_TRAIN_SYNC_T}" in
  1|true|True) rollout_bool_args+=(--rollout_train_sync_t) ;;
esac
|
| 145 |
+
|
| 146 |
+
# Launch single-node distributed training and append all output to LOG_FILE.
#
# resume_args and init_args are empty unless RESUME_PATH / INIT_MODEL_PATH are
# set. Under `set -u`, bash older than 4.4 treats "${arr[@]}" on an empty array
# as an unbound variable and aborts, so both are expanded through the
# ${arr[@]+"${arr[@]}"} form, which yields zero words when the array is empty.
# rollout_bool_args always holds at least one element and is expanded directly.
python -m torch.distributed.run \
  --nnodes=1 \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank=0 \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py \
  --data_path "${DATA_PATH}" \
  --text_column text \
  --openwebtext_split train_minus_100k \
  --detokenizer auto \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --save_dir "${SAVE_DIR}" \
  ${resume_args[@]+"${resume_args[@]}"} \
  ${init_args[@]+"${init_args[@]}"} \
  --wrap \
  --wrap_mode stream \
  --wrap_record_buffer_size 200 \
  --owt_cached_chunks \
  --owt_chunk_cache_dir "${OWT_CHUNK_CACHE_DIR}" \
  --owt_chunk_cache_write_batch 4096 \
  --owt_exact_repeat_per_chunk "${OWT_EXACT_REPEAT_PER_CHUNK}" \
  --max_len "${MAX_LEN}" \
  --batch_size "${PER_GPU_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  --dataloader_prefetch_factor "${DATALOADER_PREFETCH_FACTOR}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE}" \
  --total_steps "${TOTAL_STEPS}" \
  --log_every "${LOG_EVERY:-50}" \
  --eval_every 0 \
  --save_every "${SAVE_EVERY:-500}" \
  --latest_every "${LATEST_EVERY:-250}" \
  --lr "${LEARNING_RATE:-0.002}" \
  --weight_decay "${WEIGHT_DECAY:-0.1}" \
  --output_weight_decay "${OUTPUT_WEIGHT_DECAY}" \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --adam_eps 1e-8 \
  --optimizer muon \
  --muon_impl "${MUON_IMPL}" \
  --muon_momentum 0.95 \
  --muon_ns_steps 5 \
  --muon_update_scale 1.0 \
  --ema_decay 0.9999 \
  --ema_start_step 0 \
  --warmup_steps "${WARMUP_STEPS:-100}" \
  --lr_schedule constant_warmup \
  --min_lr 0 \
  --grad_clip 1.0 \
  --adamw_param_groups nanogpt \
  --seed 123 \
  --d_model "${D_MODEL}" \
  --cond_dim "${COND_DIM}" \
  --n_layers "${N_LAYERS}" \
  --n_heads "${N_HEADS}" \
  --dim_ff "${DIM_FF}" \
  --vocab_size_override "${VOCAB_SIZE_OVERRIDE}" \
  --dropout 0.0 \
  --model_type ddit \
  --state_format prob \
  --bridge dirichlet \
  --target_loss "${TARGET_LOSS}" \
  --linear_soft_target_power "${LINEAR_SOFT_TARGET_POWER}" \
  --linear_soft_target_min_conf "${LINEAR_SOFT_TARGET_MIN_CONF}" \
  --linear_soft_target_max_conf "${LINEAR_SOFT_TARGET_MAX_CONF}" \
  --rollout_train_prob "${ROLLOUT_TRAIN_PROB}" \
  --rollout_train_steps "${ROLLOUT_TRAIN_STEPS}" \
  --rollout_train_steps_min "${ROLLOUT_TRAIN_STEPS_MIN}" \
  --rollout_train_infer_steps "${ROLLOUT_TRAIN_INFER_STEPS}" \
  --rollout_train_time_mode "${ROLLOUT_TRAIN_TIME_MODE}" \
  --rollout_train_s_dist "${ROLLOUT_TRAIN_S_DIST}" \
  --rollout_train_s_min_frac "${ROLLOUT_TRAIN_S_MIN_FRAC}" \
  --rollout_train_s_max_frac "${ROLLOUT_TRAIN_S_MAX_FRAC}" \
  --rollout_train_s_beta_alpha "${ROLLOUT_TRAIN_S_BETA_ALPHA}" \
  --rollout_train_s_beta_beta "${ROLLOUT_TRAIN_S_BETA_BETA}" \
  --rollout_train_temp "${ROLLOUT_TRAIN_TEMP}" \
  --rollout_train_max_gamma "${ROLLOUT_TRAIN_MAX_GAMMA}" \
  "${rollout_bool_args[@]}" \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --t_sampling_mode "${T_SAMPLING_MODE:-logit_normal}" \
  --t_sampling_logit_mean "${T_LOGIT_MEAN:--1.5}" \
  --t_sampling_logit_std "${T_LOGIT_STD:-0.8}" \
  --t_sampling_eps 1e-4 \
  --dual_t \
  --corrupt_t_mode same \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio "${MIN_MASK_RATIO}" \
  --max_mask_ratio "${MAX_MASK_RATIO}" \
  --mask_ratio_floor_schedule "${MASK_RATIO_FLOOR_SCHEDULE}" \
  --mask_mixture_lowk_prob "${MASK_MIXTURE_LOWK_PROB}" \
  --mask_mixture_all_prob "${MASK_MIXTURE_ALL_PROB}" \
  --mask_mixture_lowk_clean_tokens "${LOWK_CLEAN_TOKENS}" \
  --clean_state_mode "${CLEAN_STATE_MODE}" \
  --wrong_token_replace_prob 1.0 \
  --wrong_token_schedule linear_t \
  --wrong_token_exp_k 1.0 \
  --dirichlet_concentration_min "${DIRICHLET_CONCENTRATION_MIN}" \
  --dirichlet_concentration_max "${DIRICHLET_CONCENTRATION_MAX}" \
  --dirichlet_endpoint_mode "${DIRICHLET_ENDPOINT_MODE}" \
  --dirichlet_semantic_t_mode "${DIRICHLET_SEMANTIC_T_MODE}" \
  --dirichlet_semantic_t_value "${DIRICHLET_SEMANTIC_T_VALUE}" \
  --endpoint_sequence_random_prob_alpha "${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA}" \
  --categorical_wrong_basin_token_ids "${CATEGORICAL_WRONG_BASIN_TOKEN_IDS}" \
  --categorical_wrong_basin_prob "${CATEGORICAL_WRONG_BASIN_PROB}" \
  --categorical_wrong_unigram_prob "${CATEGORICAL_WRONG_UNIGRAM_PROB}" \
  --categorical_wrong_uniform_prob "${CATEGORICAL_WRONG_UNIFORM_PROB}" \
  --categorical_wrong_prob_floor "${CATEGORICAL_WRONG_PROB_FLOOR}" \
  --categorical_wrong_basin_shared_prob "${CATEGORICAL_WRONG_BASIN_SHARED_PROB}" \
  --categorical_wrong_unigram_shared_prob "${CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB}" \
  --categorical_wrong_from_full_vocab \
  --simplex_bridge_sampler "${SIMPLEX_BRIDGE_SAMPLER}" \
  --logistic_normal_sigma_min "${LOGISTIC_NORMAL_SIGMA_MIN}" \
  --logistic_normal_sigma_max "${LOGISTIC_NORMAL_SIGMA_MAX}" \
  --logistic_normal_tau_min "${LOGISTIC_NORMAL_TAU_MIN}" \
  --logistic_normal_tau_max "${LOGISTIC_NORMAL_TAU_MAX}" \
  --eps 1e-8 \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma -1 \
  --allow_tf32 \
  --ddp_gradient_as_bucket_view \
  --bf16 \
  2>&1 | tee -a "${LOG_FILE}"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_selfcond_p05_4gpu.sh
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
|
| 7 |
+
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
| 8 |
+
export TOKENIZERS_PARALLELISM=false
|
| 9 |
+
|
| 10 |
+
STAMP="$(date +%Y%m%d_%H%M%S)"
|
| 11 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_gpt2cached_len1024_selfcond_p05_rollout1_samplewise_c1024_ddit768x12_muon_ema_gbs512_4gpu_50k_${STAMP}}"
|
| 12 |
+
export SAVE_DIR="${SAVE_DIR:-runs/${RUN_NAME}}"
|
| 13 |
+
export LOG_FILE="${LOG_FILE:-logs/selfcond_4gpu/${RUN_NAME}.log}"
|
| 14 |
+
mkdir -p "$(dirname "${LOG_FILE}")"
|
| 15 |
+
|
| 16 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
|
| 17 |
+
export MASTER_PORT="${MASTER_PORT:-31994}"
|
| 18 |
+
export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 19 |
+
export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 20 |
+
export TOTAL_STEPS="${TOTAL_STEPS:-50000}"
|
| 21 |
+
export WARMUP_STEPS="${WARMUP_STEPS:-2000}"
|
| 22 |
+
export MAX_LEN="${MAX_LEN:-1024}"
|
| 23 |
+
export NUM_WORKERS="${NUM_WORKERS:-4}"
|
| 24 |
+
export DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 25 |
+
export LOG_EVERY="${LOG_EVERY:-20}"
|
| 26 |
+
export SAVE_EVERY="${SAVE_EVERY:-1000}"
|
| 27 |
+
export LATEST_EVERY="${LATEST_EVERY:-500}"
|
| 28 |
+
export EVAL_EVERY="${EVAL_EVERY:-0}"
|
| 29 |
+
|
| 30 |
+
export OWT_CACHED_CHUNKS=1
|
| 31 |
+
export OWT_CHUNK_CACHE_REBUILD=0
|
| 32 |
+
export OWT_CHUNK_CACHE_DIR="${OWT_CHUNK_CACHE_DIR:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k}"
|
| 33 |
+
|
| 34 |
+
export OPTIMIZER="${OPTIMIZER:-muon}"
|
| 35 |
+
export LEARNING_RATE="${LEARNING_RATE:-0.002}"
|
| 36 |
+
export LR_SCHEDULE="${LR_SCHEDULE:-constant_warmup}"
|
| 37 |
+
export MIN_LR="${MIN_LR:-0}"
|
| 38 |
+
export WEIGHT_DECAY="${WEIGHT_DECAY:-0.0}"
|
| 39 |
+
export ADAM_BETA1="${ADAM_BETA1:-0.9}"
|
| 40 |
+
export ADAM_BETA2="${ADAM_BETA2:-0.95}"
|
| 41 |
+
export ADAM_EPS="${ADAM_EPS:-1e-8}"
|
| 42 |
+
export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}"
|
| 43 |
+
export MUON_NS_STEPS="${MUON_NS_STEPS:-5}"
|
| 44 |
+
export MUON_UPDATE_SCALE="${MUON_UPDATE_SCALE:-1.0}"
|
| 45 |
+
export EMA_DECAY="${EMA_DECAY:-0.9999}"
|
| 46 |
+
export EMA_START_STEP="${EMA_START_STEP:-0}"
|
| 47 |
+
|
| 48 |
+
export ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 49 |
+
export ACTIVATION_CHECKPOINTING="${ACTIVATION_CHECKPOINTING:-0}"
|
| 50 |
+
export ACTIVATION_CHECKPOINT_INTERVAL="${ACTIVATION_CHECKPOINT_INTERVAL:-2}"
|
| 51 |
+
export DDP_GRADIENT_AS_BUCKET_VIEW="${DDP_GRADIENT_AS_BUCKET_VIEW:-1}"
|
| 52 |
+
export FULL_TRAIN_STATS="${FULL_TRAIN_STATS:-0}"
|
| 53 |
+
export ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"
|
| 54 |
+
export ALLOW_EXISTING_SAVE_DIR="${ALLOW_EXISTING_SAVE_DIR:-0}"
|
| 55 |
+
|
| 56 |
+
export ROLLOUT_TRAIN_PROB="${ROLLOUT_TRAIN_PROB:-0.5}"
|
| 57 |
+
export ROLLOUT_TRAIN_STEPS="${ROLLOUT_TRAIN_STEPS:-1}"
|
| 58 |
+
export ROLLOUT_TRAIN_INFER_STEPS="${ROLLOUT_TRAIN_INFER_STEPS:-64}"
|
| 59 |
+
export ROLLOUT_TRAIN_TEMP="${ROLLOUT_TRAIN_TEMP:-1.45}"
|
| 60 |
+
export ROLLOUT_TRAIN_MAX_GAMMA="${ROLLOUT_TRAIN_MAX_GAMMA:-1.0}"
|
| 61 |
+
export ROLLOUT_TRAIN_CORRUPT_ONLY="${ROLLOUT_TRAIN_CORRUPT_ONLY:-1}"
|
| 62 |
+
export ROLLOUT_TRAIN_SAMPLEWISE="${ROLLOUT_TRAIN_SAMPLEWISE:-1}"
|
| 63 |
+
export ROLLOUT_TRAIN_COMPUTE_ALWAYS="${ROLLOUT_TRAIN_COMPUTE_ALWAYS:-0}"
|
| 64 |
+
|
| 65 |
+
bash scripts/launch_lta_owt_c1024_fullycoupled_8gpu_len1024_gpt2_cached_chunks_1m.sh
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_blockar_twostream_8gpu.sh
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
set -x
|
| 4 |
+
set -o pipefail
|
| 5 |
+
|
| 6 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 7 |
+
|
| 8 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 9 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 10 |
+
export MASTER_PORT="${MASTER_PORT:-32331}"
|
| 11 |
+
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
|
| 12 |
+
export TOKENIZERS_PARALLELISM=false
|
| 13 |
+
export PYTHONUNBUFFERED=1
|
| 14 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
| 15 |
+
|
| 16 |
+
DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
|
| 17 |
+
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
|
| 18 |
+
|
| 19 |
+
# Two-stream packs a 1024-token example into a 2048-token model sequence, so the
|
| 20 |
+
# dense first version needs a smaller per-GPU batch than the regular 1024 runs.
|
| 21 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-128}"
|
| 22 |
+
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-4}"
|
| 23 |
+
TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 24 |
+
WARMUP_STEPS="${WARMUP_STEPS:-2000}"
|
| 25 |
+
MAX_RECORDS="${MAX_RECORDS:-0}"
|
| 26 |
+
|
| 27 |
+
D_MODEL="${D_MODEL:-768}"
|
| 28 |
+
N_LAYERS="${N_LAYERS:-12}"
|
| 29 |
+
N_HEADS="${N_HEADS:-12}"
|
| 30 |
+
DIM_FF="${DIM_FF:-3072}"
|
| 31 |
+
|
| 32 |
+
# Default run name encodes the global batch size, GPU count and a launch timestamp.
RUN_NAME="${RUN_NAME:-lta_owt_t5_blockar_twostream_b8x128_mask1_uniformt_adaln_swiglu_gbs${GLOBAL_BATCH_SIZE}_${NPROC_PER_NODE}gpu_$(date +%Y%m%d_%H%M%S)}"

# RUN_NAME can be overridden by the caller, so the run directory is quoted to
# keep a name containing spaces or glob characters as a single path.
mkdir -p logs/blockar_twostream "runs/${RUN_NAME}"
|
| 35 |
+
|
| 36 |
+
torchrun \
|
| 37 |
+
--nnodes=1 \
|
| 38 |
+
--nproc_per_node="${NPROC_PER_NODE}" \
|
| 39 |
+
--master_port="${MASTER_PORT}" \
|
| 40 |
+
train.py \
|
| 41 |
+
--data_path "${DATA_PATH}" \
|
| 42 |
+
--tokenized_hf \
|
| 43 |
+
--tokenized_pad_token pad \
|
| 44 |
+
--tokenizer_path "${TOKENIZER_PATH}" \
|
| 45 |
+
--max_records "${MAX_RECORDS}" \
|
| 46 |
+
--save_dir "runs/${RUN_NAME}" \
|
| 47 |
+
--max_len 1024 \
|
| 48 |
+
--batch_size "${PER_GPU_BATCH_SIZE}" \
|
| 49 |
+
--global_batch_size "${GLOBAL_BATCH_SIZE}" \
|
| 50 |
+
--num_workers 8 \
|
| 51 |
+
--dataloader_prefetch_factor 4 \
|
| 52 |
+
--epochs 0 \
|
| 53 |
+
--total_steps "${TOTAL_STEPS}" \
|
| 54 |
+
--warmup_steps "${WARMUP_STEPS}" \
|
| 55 |
+
--log_every 100 \
|
| 56 |
+
--eval_every 0 \
|
| 57 |
+
--save_every 1000 \
|
| 58 |
+
--latest_every 500 \
|
| 59 |
+
--optimizer adamw \
|
| 60 |
+
--lr 6e-4 \
|
| 61 |
+
--lr_schedule cosine \
|
| 62 |
+
--min_lr 6e-5 \
|
| 63 |
+
--weight_decay 0.1 \
|
| 64 |
+
--output_weight_decay -1 \
|
| 65 |
+
--adamw_param_groups nanogpt \
|
| 66 |
+
--adam_beta1 0.9 \
|
| 67 |
+
--adam_beta2 0.999 \
|
| 68 |
+
--adam_eps 1e-8 \
|
| 69 |
+
--ema_decay 0.9999 \
|
| 70 |
+
--ema_start_step 0 \
|
| 71 |
+
--grad_clip 1.0 \
|
| 72 |
+
--seed 42 \
|
| 73 |
+
--d_model "${D_MODEL}" \
|
| 74 |
+
--cond_dim 128 \
|
| 75 |
+
--n_layers "${N_LAYERS}" \
|
| 76 |
+
--n_heads "${N_HEADS}" \
|
| 77 |
+
--dim_ff "${DIM_FF}" \
|
| 78 |
+
--dropout 0.0 \
|
| 79 |
+
--no-output_bias \
|
| 80 |
+
--norm_type rmsnorm \
|
| 81 |
+
--model_type ddit \
|
| 82 |
+
--ddit_mlp_type swiglu \
|
| 83 |
+
--state_format prob \
|
| 84 |
+
--bridge dirichlet \
|
| 85 |
+
--target_loss hard_ce \
|
| 86 |
+
--loss_t_weight_mode none \
|
| 87 |
+
--loss_t_min_weight 0.0 \
|
| 88 |
+
--block_ar_two_stream \
|
| 89 |
+
--block_ar_block_len 128 \
|
| 90 |
+
--target_prob 1.0 \
|
| 91 |
+
--min_t 0.0 \
|
| 92 |
+
--max_t 1.0 \
|
| 93 |
+
--t_sampling_mode uniform \
|
| 94 |
+
--dual_t \
|
| 95 |
+
--corrupt_t_mode same \
|
| 96 |
+
--corrupt_min_t 0.0 \
|
| 97 |
+
--corrupt_max_t 1.0 \
|
| 98 |
+
--min_mask_ratio 1.0 \
|
| 99 |
+
--max_mask_ratio 1.0 \
|
| 100 |
+
--mask_mixture_all_prob 1.0 \
|
| 101 |
+
--wrong_token_replace_prob 1.0 \
|
| 102 |
+
--wrong_token_schedule linear_t \
|
| 103 |
+
--dirichlet_concentration_min 1.0 \
|
| 104 |
+
--dirichlet_concentration_max 1024 \
|
| 105 |
+
--dirichlet_endpoint_mode categorical_dual_t \
|
| 106 |
+
--categorical_wrong_from_full_vocab \
|
| 107 |
+
--simplex_bridge_sampler dirichlet \
|
| 108 |
+
--eps 1e-8 \
|
| 109 |
+
--infer_steps 1024 \
|
| 110 |
+
--decode_damping 1.0 \
|
| 111 |
+
--max_gamma 1.0 \
|
| 112 |
+
--decode_solver flowmap \
|
| 113 |
+
--noise_init logistic_normal \
|
| 114 |
+
--bridge_noise_init logistic_normal \
|
| 115 |
+
--noise_sigma -1 \
|
| 116 |
+
--allow_tf32 \
|
| 117 |
+
--activation_checkpointing \
|
| 118 |
+
--activation_checkpoint_scope mlp \
|
| 119 |
+
--ddp_gradient_as_bucket_view \
|
| 120 |
+
2>&1 | tee -a "logs/blockar_twostream/${RUN_NAME}.log"
|
LTA_openwebtext_dualt/scripts/launch_lta_owt_t5_rollin_p50_randk0_4_8gpu.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
|
| 7 |
+
export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
| 8 |
+
export MASTER_PORT="${MASTER_PORT:-32254}"
|
| 9 |
+
|
| 10 |
+
export DATA_PATH="${DATA_PATH:-/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
|
| 11 |
+
export TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
|
| 12 |
+
|
| 13 |
+
export GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
|
| 14 |
+
export PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-32}"
|
| 15 |
+
export EPOCHS="${EPOCHS:-5}"
|
| 16 |
+
export TOTAL_STEPS="${TOTAL_STEPS:-1000000}"
|
| 17 |
+
export NUM_WORKERS="${NUM_WORKERS:-8}"
|
| 18 |
+
export DATALOADER_PREFETCH_FACTOR="${DATALOADER_PREFETCH_FACTOR:-4}"
|
| 19 |
+
export LOG_EVERY="${LOG_EVERY:-100}"
|
| 20 |
+
export LATEST_EVERY="${LATEST_EVERY:-1000}"
|
| 21 |
+
export ALLOW_TF32="${ALLOW_TF32:-1}"
|
| 22 |
+
|
| 23 |
+
export LR="${LR:-0.002}"
|
| 24 |
+
export WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
|
| 25 |
+
export OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}"
|
| 26 |
+
export OUTPUT_INIT_STD="${OUTPUT_INIT_STD:-0.0}"
|
| 27 |
+
|
| 28 |
+
# Corrected roll-in: model-temperature rollout, synced sampled path, random K in {0,1,2,3,4}.
|
| 29 |
+
export T_SAMPLING_MODE="${T_SAMPLING_MODE:-uniform}"
|
| 30 |
+
export ROLLOUT_TRAIN_PROB="${ROLLOUT_TRAIN_PROB:-0.50}"
|
| 31 |
+
export ROLLOUT_TRAIN_TIME_MODE="${ROLLOUT_TRAIN_TIME_MODE:-sampled_path}"
|
| 32 |
+
export ROLLOUT_TRAIN_STEPS="${ROLLOUT_TRAIN_STEPS:-4}"
|
| 33 |
+
export ROLLOUT_TRAIN_STEPS_MIN="${ROLLOUT_TRAIN_STEPS_MIN:-0}"
|
| 34 |
+
export ROLLOUT_TRAIN_INFER_STEPS="${ROLLOUT_TRAIN_INFER_STEPS:-1}"
|
| 35 |
+
export ROLLOUT_TRAIN_S_DIST="${ROLLOUT_TRAIN_S_DIST:-uniform}"
|
| 36 |
+
export ROLLOUT_TRAIN_S_MIN_FRAC="${ROLLOUT_TRAIN_S_MIN_FRAC:-0.0}"
|
| 37 |
+
export ROLLOUT_TRAIN_S_MAX_FRAC="${ROLLOUT_TRAIN_S_MAX_FRAC:-0.25}"
|
| 38 |
+
export ROLLOUT_TRAIN_TEMP="${ROLLOUT_TRAIN_TEMP:-1.0}"
|
| 39 |
+
export ROLLOUT_TRAIN_MAX_GAMMA="${ROLLOUT_TRAIN_MAX_GAMMA:-1.0}"
|
| 40 |
+
export ROLLOUT_TRAIN_CORRUPT_ONLY="${ROLLOUT_TRAIN_CORRUPT_ONLY:-1}"
|
| 41 |
+
export ROLLOUT_TRAIN_SAMPLEWISE="${ROLLOUT_TRAIN_SAMPLEWISE:-1}"
|
| 42 |
+
export ROLLOUT_TRAIN_SELECTED_ONLY="${ROLLOUT_TRAIN_SELECTED_ONLY:-1}"
|
| 43 |
+
export ROLLOUT_TRAIN_COMPUTE_ALWAYS="${ROLLOUT_TRAIN_COMPUTE_ALWAYS:-0}"
|
| 44 |
+
export ROLLOUT_TRAIN_SYNC_T="${ROLLOUT_TRAIN_SYNC_T:-1}"
|
| 45 |
+
|
| 46 |
+
export MIN_MASK_RATIO="${MIN_MASK_RATIO:-1.0}"
|
| 47 |
+
export MAX_MASK_RATIO="${MAX_MASK_RATIO:-1.0}"
|
| 48 |
+
export MASK_MIXTURE_ORIGINAL_PROB="${MASK_MIXTURE_ORIGINAL_PROB:-0.0}"
|
| 49 |
+
export MASK_MIXTURE_LOWK_PROB="${MASK_MIXTURE_LOWK_PROB:-0.0}"
|
| 50 |
+
export MASK_MIXTURE_LOWCORRUPT_PROB="${MASK_MIXTURE_LOWCORRUPT_PROB:-0.0}"
|
| 51 |
+
export MASK_MIXTURE_BLOCK_PROB="${MASK_MIXTURE_BLOCK_PROB:-0.0}"
|
| 52 |
+
export MASK_MIXTURE_ALL_PROB="${MASK_MIXTURE_ALL_PROB:-1.0}"
|
| 53 |
+
|
| 54 |
+
export TARGET_LOSS="${TARGET_LOSS:-hard_ce}"
|
| 55 |
+
export T_LOGIT_MEAN="${T_LOGIT_MEAN:--1.5}"
|
| 56 |
+
export T_LOGIT_STD="${T_LOGIT_STD:-0.8}"
|
| 57 |
+
|
| 58 |
+
export RUN_NAME="${RUN_NAME:-lta_owt_t5_rollin_p50_randk0_4_uniformt_temp1_synct_mask1_gbs${GLOBAL_BATCH_SIZE}_${NPROC_PER_NODE}gpu_$(date +%Y%m%d_%H%M%S)}"
|
| 59 |
+
|
| 60 |
+
bash scripts/launch_lta_owt_elfaligned_t5_logitnormal_tokenized_8gpu.sh
|
LTA_openwebtext_dualt/scripts/make_compact_gpt2_bpe_tokenizer.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
SPECIAL_TOKENS = ["<pad>", "</s>", "<unk>", "<|endoftext|>"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the compact-tokenizer builder.

    Returns:
        Namespace with ``source_tokenizer``, ``output_dir``, ``vocab_size``
        (int) and ``special_tokens`` (defaults to ``SPECIAL_TOKENS``).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Create compact byte-level GPT-2 BPE tokenizers by truncating the "
            "merge table while preserving byte fallback coverage."
        )
    )
    # The two path options are plain required strings.
    for flag in ("--source_tokenizer", "--output_dir"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--vocab_size", type=int, required=True)
    parser.add_argument(
        "--special_tokens",
        nargs="*",
        default=SPECIAL_TOKENS,
        help="Special tokens to reserve first in the compact vocab.",
    )
    return parser.parse_args()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize_merge(merge: str | list[str]) -> tuple[str, str]:
    """Return a BPE merge entry as a ``(left, right)`` pair of strings.

    Two encodings are accepted: a single string ``"left right"`` split at the
    first space, or a two-element list/tuple ``[left, right]``.

    Args:
        merge: One entry of the tokenizer's ``merges`` table.

    Returns:
        The left and right symbols of the merge.

    Raises:
        ValueError: If ``merge`` is a string without a space, a sequence whose
            length is not 2, or any other type.
    """
    if isinstance(merge, str):
        # partition() reports a missing separator explicitly, so a malformed
        # string reaches the descriptive error below instead of failing with
        # an opaque tuple-unpacking ValueError.
        left, sep, right = merge.partition(" ")
        if sep:
            return left, right
    elif isinstance(merge, (list, tuple)) and len(merge) == 2:
        return str(merge[0]), str(merge[1])
    raise ValueError(f"Unsupported merge entry: {merge!r}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def merge_to_json(left: str, right: str, like: str | list[str]) -> str | list[str]:
    """Serialize a merge pair in the same JSON shape as the entry ``like``.

    Args:
        left: Left side of the merge.
        right: Right side of the merge.
        like: The source entry whose representation should be mirrored.

    Returns:
        ``"left right"`` when ``like`` is a string, otherwise ``[left, right]``.
    """
    return f"{left} {right}" if isinstance(like, str) else [left, right]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def base_byte_tokens(orig_id_to_token: dict[int, str]) -> Iterable[str]:
    """Yield the tokens stored under ids 0..255 of the source vocabulary.

    Args:
        orig_id_to_token: Mapping from original token id to token string.

    Yields:
        The token for each id from 0 to 255, in increasing id order.

    Raises:
        ValueError: As soon as one of those ids has no token (raised lazily,
            when the generator reaches the missing id).
    """
    # GPT-2 byte-level BPE starts with exactly 256 byte fallback tokens.
    for idx in range(256):
        byte_token = orig_id_to_token.get(idx)
        if byte_token is not None:
            yield byte_token
            continue
        raise ValueError(f"Source tokenizer is missing byte token id={idx}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def main() -> None:
    """Build a compact BPE tokenizer from a byte-level ``tokenizer.json``.

    The compact vocabulary is laid out in this order:

    1. the requested special tokens (duplicates dropped, first occurrence wins),
    2. the tokens stored under ids 0..255 of the source vocabulary,
    3. outputs of the source merges, taken in their original order, until
       ``--vocab_size`` entries exist.

    A merge is kept only when both of its inputs are already in the compact
    vocabulary and its output exists in the source vocabulary, so the retained
    merge table is always self-consistent.

    Writes ``tokenizer.json`` and ``tokenizer_meta.json`` into ``--output_dir``
    and prints the metadata as a single JSON line.

    Raises:
        ValueError: If ``--vocab_size`` does not exceed the number of special
            tokens plus 256, or if the source tokenizer is malformed.
        RuntimeError: If the merge table is exhausted before the requested
            vocabulary size is reached.
    """
    args = parse_args()
    src_path = Path(args.source_tokenizer)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    with src_path.open("r", encoding="utf-8") as f:
        tokenizer = json.load(f)

    model = tokenizer["model"]
    orig_vocab: dict[str, int] = {str(k): int(v) for k, v in model["vocab"].items()}
    orig_id_to_token = {idx: tok for tok, idx in orig_vocab.items()}
    orig_merges = list(model.get("merges", []))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    special_tokens = list(dict.fromkeys(args.special_tokens))

    if args.vocab_size <= len(special_tokens) + 256:
        raise ValueError(
            f"vocab_size={args.vocab_size} must exceed specials+byte alphabet "
            f"({len(special_tokens)}+256)"
        )

    ordered_tokens: list[str] = []
    token_set: set[str] = set()

    def add_token(tok: str) -> None:
        # Idempotent: a token already present keeps its original position.
        if tok not in token_set:
            token_set.add(tok)
            ordered_tokens.append(tok)

    for tok in special_tokens:
        add_token(tok)
    for tok in base_byte_tokens(orig_id_to_token):
        add_token(tok)

    # The size check above guarantees there is room for at least one merge
    # output here, and the loop stops as soon as the vocabulary is full, so no
    # additional capacity check is needed before add_token().
    compact_merges: list[str | list[str]] = []
    for merge in orig_merges:
        left, right = normalize_merge(merge)
        merged = left + right
        if left not in token_set or right not in token_set:
            continue
        if merged not in orig_vocab:
            continue
        add_token(merged)
        compact_merges.append(merge_to_json(left, right, merge))
        if len(ordered_tokens) >= args.vocab_size:
            break

    if len(ordered_tokens) != args.vocab_size:
        raise RuntimeError(
            f"Could only build vocab of size {len(ordered_tokens)}; "
            f"requested {args.vocab_size}"
        )

    compact_vocab = {tok: idx for idx, tok in enumerate(ordered_tokens)}

    compact_tokenizer = dict(tokenizer)
    compact_tokenizer["truncation"] = None
    compact_tokenizer["padding"] = None
    compact_tokenizer["added_tokens"] = [
        {
            "id": compact_vocab[tok],
            "special": True,
            "content": tok,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": True,
        }
        for tok in special_tokens
    ]
    compact_tokenizer["post_processor"] = None
    out_model = dict(model)
    out_model["vocab"] = compact_vocab
    out_model["merges"] = compact_merges
    # Only advertise "<unk>" when it really is in the compact vocabulary
    # (it is with the default specials). A custom --special_tokens list that
    # omits it would otherwise leave the model pointing at a missing token.
    out_model["unk_token"] = "<unk>" if "<unk>" in compact_vocab else None
    compact_tokenizer["model"] = out_model

    with (out_dir / "tokenizer.json").open("w", encoding="utf-8") as f:
        json.dump(compact_tokenizer, f, ensure_ascii=False)

    meta = {
        "source_tokenizer": str(src_path),
        "vocab_size": args.vocab_size,
        "actual_vocab_size": len(compact_vocab),
        "num_merges": len(compact_merges),
        "special_tokens": special_tokens,
        "construction": "compact_gpt2_bpe_prefix_merges",
    }
    with (out_dir / "tokenizer_meta.json").open("w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
        f.write("\n")

    print(json.dumps(meta, ensure_ascii=False))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
main()
|
LTA_openwebtext_dualt/scripts/run_lta_lm1b_linear_simplex_len128_4gpu_save1k_with_watch.sh
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Configuration section of the LM1B linear-simplex launcher. Every setting
# below is a default that the caller's environment may override.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

# --- Process environment handed to the trainer and the watcher ---------------
: "${CUDA_VISIBLE_DEVICES:=0,1,2,3}"
: "${OMP_NUM_THREADS:=1}"
: "${NCCL_DEBUG:=WARN}"
: "${TORCH_DISTRIBUTED_TIMEOUT:=3600}"
export CUDA_VISIBLE_DEVICES OMP_NUM_THREADS NCCL_DEBUG TORCH_DISTRIBUTED_TIMEOUT
export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1

# --- Data, tokenizer and scorer locations ------------------------------------
: "${DATA_PATH:=data/lm1b_train_parquet}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/workspace/imagenet_handoff_20260327/nlp_dts_light/assets/distilbert-base-uncased/tokenizer.json}"
: "${SCORER:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-large-standard}"

# --- torch.distributed rendezvous --------------------------------------------
: "${NNODES:=1}"
: "${NPROC_PER_NODE:=4}"
: "${NODE_RANK:=0}"
: "${MASTER_ADDR:=127.0.0.1}"
: "${MASTER_PORT:=32931}"

# --- Schedule, batching and logging cadence ----------------------------------
: "${GLOBAL_BATCH_SIZE:=512}"
: "${PER_GPU_BATCH_SIZE:=64}"
: "${TOTAL_STEPS:=20000}"
: "${WARMUP_STEPS:=2500}"
: "${MAX_LEN:=128}"
: "${NUM_WORKERS:=0}"
: "${LOG_EVERY:=100}"
: "${EVAL_EVERY:=0}"
: "${SAVE_EVERY:=1000}"
: "${LATEST_EVERY:=1000}"

# --- Optimizer and model shape -----------------------------------------------
: "${LR:=3e-4}"
: "${D_MODEL:=768}"
: "${COND_DIM:=128}"
: "${N_LAYERS:=12}"
: "${N_HEADS:=12}"
: "${DIM_FF:=3072}"
: "${DROPOUT:=0.1}"

# --- Bridge noise and endpoint decoding --------------------------------------
: "${LINEAR_NOISE_SIGMA:=3.0}"
: "${WRONG_TOKEN_REPLACE_PROB:=0.0}"
: "${WRONG_TOKEN_SCHEDULE:=constant}"
: "${ENDPOINT_TEMP_START:=1.45}"
: "${ENDPOINT_TEMP_END:=0.8}"
: "${ENDPOINT_PROJECTION:=soft}"
: "${FINAL_FROM:=model_t1}"

# --- Background checkpoint watcher -------------------------------------------
: "${ENABLE_WATCHER:=1}"
: "${WATCH_CUDA_VISIBLE_DEVICES:=0}"
: "${WATCH_SAMPLES:=64}"
: "${WATCH_STEPS:=128}"
: "${WATCH_POLL_SECONDS:=30}"

# --- Resume handling ---------------------------------------------------------
: "${RESUME_PATH:=}"
: "${ALLOW_EXISTING_SAVE_DIR:=0}"

# The LM1B parquet directory is only used with a single-process dataloader.
case "${DATA_PATH}" in
  *"lm1b_train_parquet"*)
    if [[ "${NUM_WORKERS}" != "0" ]]; then
      echo "LM1B has only 9 parquet shards; forcing NUM_WORKERS=0 to avoid empty DDP dataloader shards." >&2
      NUM_WORKERS=0
    fi
    ;;
esac

# Fail early (exit code 2) when a required input is missing.
[[ -e "${DATA_PATH}" ]] || {
  echo "Missing DATA_PATH: ${DATA_PATH}" >&2
  exit 2
}
[[ -f "${TOKENIZER_PATH}" ]] || {
  echo "Missing TOKENIZER_PATH: ${TOKENIZER_PATH}" >&2
  exit 2
}

# Run identity and output locations; the timestamp is only generated when
# RUN_NAME was not supplied by the caller.
: "${RUN_NAME:=lta_lm1b_linear_simplex_len${MAX_LEN}_sigma${LINEAR_NOISE_SIGMA}_gbs${GLOBAL_BATCH_SIZE}_4gpu_20k_save1k_$(date +%Y%m%d_%H%M%S)}"
: "${SAVE_DIR:=runs/${RUN_NAME}}"
: "${LOG_DIR:=logs/lm1b_linear_simplex_len128_4gpu}"
: "${LOG_FILE:=${LOG_DIR}/${RUN_NAME}.log}"
: "${WATCH_LOG:=${LOG_DIR}/${RUN_NAME}.watch.log}"

# Do not silently reuse a save directory that already holds a run, unless the
# caller is resuming or has explicitly allowed it.
if [[ -z "${RESUME_PATH}" && "${ALLOW_EXISTING_SAVE_DIR}" != "1" ]] && [[ -f "${SAVE_DIR}/args.json" ]]; then
  echo "Refusing to start because SAVE_DIR already contains args.json: ${SAVE_DIR}" >&2
  exit 2
fi

mkdir -p "${LOG_DIR}" runs "${SAVE_DIR}"

# Extra trainer arguments, present only when resuming.
RESUME_ARGS=()
[[ -z "${RESUME_PATH}" ]] || RESUME_ARGS+=(--resume_path "${RESUME_PATH}")

echo "[launch] method=lm1b_linear_simplex_algebraic_len128_4gpu"
echo "[launch] host=$(hostname) time=$(date -Iseconds)"
echo "[launch] run_name=${RUN_NAME}"
echo "[launch] save_dir=${SAVE_DIR}"
echo "[launch] bridge=prob linear simplex: p_t=(1-t)*softmax(${LINEAR_NOISE_SIGMA}*N)+t*onehot(y)"
echo "[launch] wrong_endpoint: prob=${WRONG_TOKEN_REPLACE_PROB} schedule=${WRONG_TOKEN_SCHEDULE}"
echo "[launch] infer=linear algebraic: p_next=(1-s_next)*p0+s_next*a_theta"
echo "[launch] watcher=${ENABLE_WATCHER} watch_gpu=${WATCH_CUDA_VISIBLE_DEVICES} watch_samples=${WATCH_SAMPLES}"
|
| 99 |
+
|
| 100 |
+
# Poll SAVE_DIR for step_<N>.pt checkpoints and run the gen-PPL evaluation on
# each one exactly once, in increasing step order.
#
# Reads: SAVE_DIR, RUN_NAME, TOTAL_STEPS, TOKENIZER_PATH, SCORER, MAX_LEN,
#        LINEAR_NOISE_SIGMA, ENDPOINT_TEMP_START/END, ENDPOINT_PROJECTION,
#        FINAL_FROM, WATCH_* settings, WATCH_LOG.
# Returns once a checkpoint with step >= TOTAL_STEPS has been processed.
# An evaluation failure is logged but never stops the watcher (best effort).
run_watcher() {
  local last_done=0
  local ckpt base step out_dir next_step next_ckpt
  while true; do
    # Drain every pending checkpoint. The smallest unprocessed step is picked
    # on each pass instead of trusting glob order, which is lexicographic and
    # would visit step_10000 before step_2000 when names are not zero-padded.
    while true; do
      next_step=""
      next_ckpt=""
      for ckpt in "${SAVE_DIR}"/step_*.pt; do
        [[ -f "${ckpt}" ]] || continue
        base="$(basename "${ckpt}")"
        step="${base#step_}"
        step="${step%.pt}"
        # Skip names such as step_final.pt; they would break the arithmetic below.
        [[ "${step}" =~ ^[0-9]+$ ]] || continue
        # 10# forces base 10 so zero-padded steps are not parsed as octal.
        step=$((10#${step}))
        (( step > last_done )) || continue
        if [[ -z "${next_step}" ]] || (( step < next_step )); then
          next_step="${step}"
          next_ckpt="${ckpt}"
        fi
      done
      [[ -n "${next_step}" ]] || break

      step="${next_step}"
      ckpt="${next_ckpt}"
      out_dir="docs/lta_samples/metrics_$(date +%Y%m%d)/${RUN_NAME}_step${step}_linear_alg_t${ENDPOINT_TEMP_START}_to_${ENDPOINT_TEMP_END}_n${WATCH_SAMPLES}"
      echo "[watch] infer step=${step} ckpt=${ckpt} out=${out_dir}" | tee -a "${WATCH_LOG}"
      CUDA_VISIBLE_DEVICES="${WATCH_CUDA_VISIBLE_DEVICES}" python scripts/eval_lm1b_linear_simplex_genppl.py \
        --checkpoint "${ckpt}" \
        --tokenizer_path "${TOKENIZER_PATH}" \
        --scorer "${SCORER}" \
        --out_dir "${out_dir}" \
        --n_samples "${WATCH_SAMPLES}" \
        --max_len "${MAX_LEN}" \
        --steps "${WATCH_STEPS}" \
        --batch_size 1 \
        --score_batch 1 \
        --noise_init logistic_normal \
        --noise_sigma "${LINEAR_NOISE_SIGMA}" \
        --endpoint_temp_start "${ENDPOINT_TEMP_START}" \
        --endpoint_temp_end "${ENDPOINT_TEMP_END}" \
        --endpoint_projection "${ENDPOINT_PROJECTION}" \
        --endpoint_top_p 1.0 \
        --final_from "${FINAL_FROM}" \
        --seed "$((20260524 + step))" 2>&1 | tee -a "${WATCH_LOG}" || true
      # Mark the step as handled even if the evaluation failed, so one bad
      # checkpoint cannot stall the watcher.
      last_done="${step}"
    done
    if (( last_done >= TOTAL_STEPS )); then
      break
    fi
    sleep "${WATCH_POLL_SECONDS}"
  done
}
|
| 141 |
+
|
| 142 |
+
# Start the optional background watcher, run distributed training in the
# foreground, then wait for the watcher to finish its last evaluations.
WATCH_PID=""
if [[ "${ENABLE_WATCHER}" == "1" ]]; then
  run_watcher &
  WATCH_PID=$!
  echo "[launch] watcher_pid=${WATCH_PID}"
fi

# Stop the watcher loop if it is still alive when this script exits. Without
# this, a failed training run (set -e / pipefail) would leave the watcher
# polling forever. An evaluation already in flight is left to finish.
stop_watcher() {
  if [[ -n "${WATCH_PID}" ]] && kill -0 "${WATCH_PID}" 2>/dev/null; then
    kill "${WATCH_PID}" 2>/dev/null || true
  fi
}
trap stop_watcher EXIT

# ${RESUME_ARGS[@]+...} expands to nothing for an empty array; a bare
# "${RESUME_ARGS[@]}" is an "unbound variable" error under set -u on
# bash versions older than 4.4.
python -m torch.distributed.run \
  --nnodes="${NNODES}" \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank="${NODE_RANK}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py \
  --data_path "${DATA_PATH}" \
  --openwebtext_split all \
  --detokenizer auto \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --save_dir "${SAVE_DIR}" \
  --wrap \
  --wrap_mode stream \
  --wrap_record_buffer_size 200 \
  --max_len "${MAX_LEN}" \
  --batch_size "${PER_GPU_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  --global_batch_size "${GLOBAL_BATCH_SIZE}" \
  --total_steps "${TOTAL_STEPS}" \
  --log_every "${LOG_EVERY}" \
  --eval_every "${EVAL_EVERY}" \
  --save_every "${SAVE_EVERY}" \
  --latest_every "${LATEST_EVERY}" \
  --lr "${LR}" \
  --weight_decay 0 \
  --adam_beta1 0.9 \
  --adam_beta2 0.999 \
  --adam_eps 1e-8 \
  --warmup_steps "${WARMUP_STEPS}" \
  --lr_schedule constant_warmup \
  --grad_clip 1.0 \
  --seed 123 \
  --d_model "${D_MODEL}" \
  --cond_dim "${COND_DIM}" \
  --n_layers "${N_LAYERS}" \
  --n_heads "${N_HEADS}" \
  --dim_ff "${DIM_FF}" \
  --dropout "${DROPOUT}" \
  --model_type ddit \
  --state_format prob \
  --bridge prob \
  --target_loss hard_ce \
  --target_prob 1.0 \
  --min_t 0.0 \
  --max_t 1.0 \
  --t_sampling_mode uniform \
  --dual_t \
  --corrupt_t_mode same \
  --corrupt_min_t 0.0 \
  --corrupt_max_t 1.0 \
  --min_mask_ratio 1.0 \
  --max_mask_ratio 1.0 \
  --mask_mixture_all_prob 1.0 \
  --wrong_token_replace_prob "${WRONG_TOKEN_REPLACE_PROB}" \
  --wrong_token_schedule "${WRONG_TOKEN_SCHEDULE}" \
  --eps 1e-8 \
  --infer_steps 128 \
  --decode_damping 1.0 \
  --max_gamma 1.0 \
  --decode_solver flowmap \
  --noise_init logistic_normal \
  --bridge_noise_init logistic_normal \
  --noise_sigma "${LINEAR_NOISE_SIGMA}" \
  ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} \
  --bf16 2>&1 | tee -a "${LOG_FILE}"

# Training succeeded: let the watcher evaluate the remaining checkpoints.
if [[ "${ENABLE_WATCHER}" == "1" ]]; then
  wait "${WATCH_PID}" || true
fi
|
LTA_openwebtext_dualt/scripts/run_lta_owt_t5elf_absrope_time4_dirichlet_len1025_C1_to_1024_8gpu_1m_mask1_sameT_prebos_save10k.sh
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

# Preset wrapper: exports a fixed configuration and then delegates to the
# generic Dirichlet launcher at the bottom of this file.
#
# Data: OWT through the existing ELF/T5 tokenized pipeline, i.e.
#   [BOS] + up to 1024 existing T5 tokens.
# The ELF T5 records already end with </s>, so only a BOS is prepended here.
# T5 has no separate BOS token in this tokenizer; the repo wrapper falls back
# to bos_id = eos_id = </s>.
#
# Backbone: ddit_elf = RMSNorm/SwiGLU/QK-norm + RoPE + 4 prefix time tokens,
#   with learned absolute position embeddings added before RoPE.
#
# Bridge: Dirichlet C=1->1024, mask_ratio=1.0, model t and corruption t shared.

# --- Tokenized dataset -------------------------------------------------------
: "${DATA_PATH:=/e2e-data/evad-tech-vla/wanghan58/data/embedded-language-flows/openwebtext-t5}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/hf/t5-small/tokenizer.json}"
: "${TOKENIZED_PAD_TOKEN:=pad}"
TOKENIZED_HF=1
TOKENIZED_PREPEND_BOS=1
TOKENIZED_APPEND_EOS=0
TOKENIZED_STRIP_EDGE_SPECIALS=0
export DATA_PATH TOKENIZER_PATH TOKENIZED_HF TOKENIZED_PAD_TOKEN \
  TOKENIZED_PREPEND_BOS TOKENIZED_APPEND_EOS TOKENIZED_STRIP_EDGE_SPECIALS

# --- Sequence length, vocabulary and concentration range ---------------------
: "${MAX_LEN:=1025}"
: "${VOCAB_SIZE:=32100}"
: "${CMIN:=1}"
: "${CMAX:=1024}"
export MAX_LEN VOCAB_SIZE CMIN CMAX

# --- Model and corruption settings (fixed values are not overridable) --------
: "${ELF_NUM_TIME_TOKENS:=4}"
: "${ELF_NUM_MODEL_MODE_TOKENS:=0}"
: "${QK_NORM:=1}"
: "${CATEGORICAL_WRONG_PROB_FLOOR:=0.0}"
MODEL_TYPE=ddit_elf
ABS_POS_EMBED=1
CORRUPT_T_MODE=same
MIN_MASK_RATIO=1.0
MAX_MASK_RATIO=1.0
export MODEL_TYPE ELF_NUM_TIME_TOKENS ELF_NUM_MODEL_MODE_TOKENS QK_NORM \
  ABS_POS_EMBED CORRUPT_T_MODE MIN_MASK_RATIO MAX_MASK_RATIO \
  CATEGORICAL_WRONG_PROB_FLOOR

# --- Hardware and training schedule ------------------------------------------
: "${CUDA_VISIBLE_DEVICES:=0,1,2,3,4,5,6,7}"
: "${NPROC_PER_NODE:=8}"
: "${GLOBAL_BATCH_SIZE:=512}"
: "${PER_GPU_BATCH_SIZE:=32}"
: "${TOTAL_STEPS:=1000000}"
: "${WARMUP_STEPS:=2500}"
: "${SAVE_EVERY:=10000}"
: "${LATEST_EVERY:=1000}"
: "${LOG_EVERY:=100}"
export CUDA_VISIBLE_DEVICES NPROC_PER_NODE GLOBAL_BATCH_SIZE PER_GPU_BATCH_SIZE \
  TOTAL_STEPS WARMUP_STEPS SAVE_EVERY LATEST_EVERY LOG_EVERY

# --- Run naming (DATE_TAG is also used by the watcher output path) -----------
: "${DATE_TAG:=$(date +%Y%m%d)}"
: "${RUN_NAME:=lta_owt_t5elf_absrope_time4_dirichlet_len1025_C1_to_1024_prebos_mask1_sameT_gbs512_b32_8gpu_1m_save10k_${DATE_TAG}}"
export DATE_TAG RUN_NAME

# --- Checkpoint watcher ------------------------------------------------------
: "${WATCH_ENABLED:=1}"
: "${WATCH_STEP_INTERVAL:=10000}"
: "${WATCH_N_SAMPLES:=128}"
: "${WATCH_CUDA_VISIBLE_DEVICES:=7}"
: "${WATCH_DECODE_MODE:=dual_line_probe}"
: "${WATCH_DECODE_BATCH:=1}"
: "${WATCH_DUAL_SEMANTIC_POWER:=1.5}"
: "${WATCH_DUAL_EARLY_TEMP:=2.8}"
: "${WATCH_DUAL_LATE_TEMP:=1.45}"
: "${WATCH_DUAL_TEMP_END:=0.55}"
: "${WATCH_DUAL_TEMP_POWER:=1.5}"
: "${WATCH_ENDPOINT_TEMP:=1.45}"
: "${WATCH_ENDPOINT_TOP_P:=0.95}"
: "${WATCH_GUMBEL_TAU_START:=1.0}"
: "${WATCH_GUMBEL_TAU_END:=0.2}"
: "${WATCH_OUT_BASE:=docs/lta_samples/metrics_${DATE_TAG}/owt_t5elf_absrope_time4_len1025_C1_to_1024_prebos_mask1_sameT_dualline_dirres_c${CMIN}_${CMAX}_n${WATCH_N_SAMPLES}/${RUN_NAME}}"
: "${WATCH_LOG_DIR:=logs/owt_t5elf_absrope_time4_len1025_C1_to_1024_prebos_mask1_sameT_dualline_watch}"
export WATCH_ENABLED WATCH_STEP_INTERVAL WATCH_N_SAMPLES WATCH_CUDA_VISIBLE_DEVICES \
  WATCH_DECODE_MODE WATCH_DECODE_BATCH WATCH_DUAL_SEMANTIC_POWER \
  WATCH_DUAL_EARLY_TEMP WATCH_DUAL_LATE_TEMP WATCH_DUAL_TEMP_END \
  WATCH_DUAL_TEMP_POWER WATCH_ENDPOINT_TEMP WATCH_ENDPOINT_TOP_P \
  WATCH_GUMBEL_TAU_START WATCH_GUMBEL_TAU_END WATCH_OUT_BASE WATCH_LOG_DIR

bash scripts/run_lta_owt_dirichlet_len1024_Cv_to_2v_8gpu_save1k_with_gumbel_watch.sh
|
LTA_openwebtext_dualt/scripts/run_train8_ctx1024_rollin_sweep_4gpu.sh
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Configuration section of the ctx1024 roll-in sweep driver. All settings are
# defaults that the caller's environment may override.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1

BASE_CACHE="${BASE_CACHE:-/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
TOKENIZER_PATH="${TOKENIZER_PATH:-/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
MAX_LEN="${MAX_LEN:-1024}"
N_SAMPLES="${N_SAMPLES:-64}"
INFER_STEPS="${INFER_STEPS:-128}"
# Training advances in chunks of STEP_CHUNK steps, up to MAX_TOTAL_STEPS.
STEP_CHUNK="${STEP_CHUNK:-1000}"
MAX_TOTAL_STEPS="${MAX_TOTAL_STEPS:-12000}"
PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE:-128}"
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-512}"
LEARNING_RATE="${LEARNING_RATE:-0.002}"
WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
GROUP_STAMP="${GROUP_STAMP:-$(date +%Y%m%d_%H%M%S)}"
# "-" (not ":-") so that an explicitly empty WAIT_FOR_RUN="" disables the
# wait; with ":-" the empty value was replaced by the default and the
# later `[[ -n "${WAIT_FOR_RUN}" ]]` guard could never be false.
WAIT_FOR_RUN="${WAIT_FOR_RUN-train8_rollin_len1024_rollin_p50_s4_i32_20260517_1855ctx1024bs128}"
OUT_ROOT="${OUT_ROOT:-docs/lta_samples/metrics_20260517/ctx1024_rollin_sweep_bs512_ode128_${GROUP_STAMP}}"
DRIVER_LOG="${DRIVER_LOG:-logs/ctx1024_rollin_sweep_4gpu/${GROUP_STAMP}.log}"
CURVE_CSV="${CURVE_CSV:-${OUT_ROOT}/hit_ratio_curve.csv}"
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"

# Cached chunk directory for this sequence length; its meta.json provides the
# vocabulary size (compact_vocab_size if present, otherwise vocab_size).
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(
python - "$cache" <<'PY'
import json
import sys
from pathlib import Path
meta = json.loads((Path(sys.argv[1]) / "meta.json").read_text())
print(int(meta.get("compact_vocab_size", meta.get("vocab_size"))))
PY
)"

# Create the results CSV with its header on first use; later runs append.
if [[ ! -f "${CURVE_CSV}" ]]; then
  echo "config,run_name,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
fi
|
| 42 |
+
|
| 43 |
+
# Print the highest N among runs/<run_name>/step_<N>.pt, or 0 when the run
# directory holds no such checkpoint.
#   $1 - run name (directory under runs/)
latest_step() {
  local run_name="$1"
  python - "$run_name" <<'PY'
import re
import sys
from pathlib import Path

step_file = re.compile(r"step_(\d+)\.pt$")
run_dir = Path("runs") / sys.argv[1]
found = (step_file.search(ckpt.name) for ckpt in run_dir.glob("step_*.pt"))
print(max((int(hit.group(1)) for hit in found if hit), default=0))
PY
}
|
| 58 |
+
|
| 59 |
+
# Print a TCP port number that the OS reported as free on 127.0.0.1.
# The socket is closed before returning, so the port is not reserved.
free_port() {
  python - <<'PY'
import socket

probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    # Port 0 asks the OS to pick any unused port.
    probe.bind(("127.0.0.1", 0))
    print(probe.getsockname()[1])
finally:
    probe.close()
PY
}
|
| 67 |
+
|
| 68 |
+
# Evaluate the latest checkpoint of a run and append one row to CURVE_CSV.
#   $1 - config label (first CSV column, also used in the output path)
#   $2 - run name (directory under runs/)
#   $3 - target step, used only to name the output directory
# Reads: OUT_ROOT, cache, TOKENIZER_PATH, MAX_LEN, N_SAMPLES, INFER_STEPS,
#        GLOBAL_BATCH_SIZE, CURVE_CSV, EVAL_CUDA_VISIBLE_DEVICES (default 0).
eval_latest() {
  local config="$1"
  local run_name="$2"
  local target_step="$3"
  local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
  mkdir -p "${out_dir}"
  CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py \
    --runs_glob "runs/${run_name}" \
    --data_dir "${cache}" \
    --tokenizer_path "${TOKENIZER_PATH}" \
    --out_dir "${out_dir}" \
    --max_len "${MAX_LEN}" \
    --n_samples "${N_SAMPLES}" \
    --batch_size "${N_SAMPLES}" \
    --latest_only \
    --endpoint_softenings none \
    --steps "${INFER_STEPS}" \
    --decode_rule flowmap \
    --time_schedule logit_normal \
    --time_logit_mean -1.5 \
    --time_logit_std 0.8 \
    --model_t_mode post \
    --c_min 1 \
    --c_max 512 \
    --late_temp 1.0 \
    --final_from state \
    --final_decode argmax
  # Summarize the last JSONL record: print a RESULT line and append a CSV row.
  python - "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$CURVE_CSV" "$config" "$run_name" <<'PY'
import csv
import json
import sys
from pathlib import Path
out = Path(sys.argv[1])
n = int(sys.argv[2])
global_batch = int(sys.argv[3])
max_len = int(sys.argv[4])
curve = Path(sys.argv[5])
config = sys.argv[6]
run_name = sys.argv[7]
row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
views = int(row["ckpt_step"]) * global_batch
tokens = views * max_len
print(
    "RESULT "
    f"config={config} run={run_name} ckpt_step={row['ckpt_step']} "
    f"views={views} token_acc={row['token_acc_mean']:.4f} "
    f"exact={row['exact_count']}/{n} exact_refs={row['exact_ref_count']} "
    f"hits={row['exact_ref_hits']}",
    flush=True,
)
# csv.writer quotes and escapes fields as needed (commas, embedded double
# quotes), which the previous hand-formatted row did not do for
# exact_ref_hits. Values are stringified first so they are written exactly
# as an f-string would render them.
fields = [
    config,
    run_name,
    row["ckpt_step"],
    views,
    tokens,
    row["token_acc_mean"],
    row["exact_count"],
    row["exact_ref_count"],
    row["exact_ref_hits"],
]
with curve.open("a", encoding="utf-8", newline="") as f:
    csv.writer(f, lineterminator="\n").writerow([str(value) for value in fields])
PY
}
|
| 124 |
+
|
| 125 |
+
# Train one sweep configuration in STEP_CHUNK-sized chunks, evaluating after
# each chunk, until one of the stop conditions is met.
#   $1 - config label           $4 - rollout inference steps
#   $2 - rollout probability    $5 - output-layer weight decay
#   $3 - rollout steps          $6 - rollout sync_t flag
# Stops when: the latest checkpoint reaches MAX_TOTAL_STEPS; the last CSV row
# reports exact_count >= EARLY_HIT_EXACT_COUNT (default 60); or a chunk
# produced no newer step_*.pt checkpoint (guards against looping forever).
run_config() {
  local config="$1"
  local rollout_prob="$2"
  local rollout_steps="$3"
  local rollout_infer_steps="$4"
  local output_wd="$5"
  local sync_t="$6"
  local run_name="${RUN_PREFIX:-train8_ctx1024}_${config}_${GROUP_STAMP}"
  # Minimum exact_count (CSV column 7) that counts as an early hit.
  local early_hit_exact="${EARLY_HIT_EXACT_COUNT:-60}"
  local prev_step=-1
  local step_now target_step resume_path
  echo "[ctx1024-sweep] config=${config} run=${run_name} p=${rollout_prob} steps=${rollout_steps} infer=${rollout_infer_steps} outwd=${output_wd} sync_t=${sync_t}" | tee -a "${DRIVER_LOG}"
  while :; do
    step_now="$(latest_step "${run_name}")"
    if [[ "${step_now}" -ge "${MAX_TOTAL_STEPS}" ]]; then
      echo "[ctx1024-sweep] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
      break
    fi
    # If the previous chunk did not leave a newer checkpoint behind, repeating
    # it would never terminate; give up on this config instead.
    if [[ "${step_now}" -le "${prev_step}" ]]; then
      echo "[ctx1024-sweep] no-progress config=${config} step=${step_now}; stopping this config" | tee -a "${DRIVER_LOG}"
      break
    fi
    prev_step="${step_now}"
    target_step=$((step_now + STEP_CHUNK))
    if [[ "${target_step}" -gt "${MAX_TOTAL_STEPS}" ]]; then
      target_step="${MAX_TOTAL_STEPS}"
    fi
    # Resume from the rolling checkpoint when an earlier chunk left one.
    resume_path=""
    if [[ -f "runs/${run_name}/latest.pt" ]]; then
      resume_path="runs/${run_name}/latest.pt"
    fi
    echo "[ctx1024-sweep] train config=${config} from=${step_now} to=${target_step}" | tee -a "${DRIVER_LOG}"
    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
    NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
    MASTER_PORT="$(free_port)" \
    OWT_CHUNK_CACHE_DIR="${cache}" \
    OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
    MAX_LEN="${MAX_LEN}" \
    VOCAB_SIZE_OVERRIDE="${vocab_size}" \
    D_MODEL="${D_MODEL:-192}" \
    COND_DIM="${COND_DIM:-64}" \
    N_LAYERS="${N_LAYERS:-3}" \
    N_HEADS="${N_HEADS:-3}" \
    DIM_FF="${DIM_FF:-768}" \
    TOTAL_STEPS="${target_step}" \
    PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
    GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
    NUM_WORKERS="${NUM_WORKERS:-0}" \
    LOG_EVERY="${LOG_EVERY:-100}" \
    SAVE_EVERY="${STEP_CHUNK}" \
    LATEST_EVERY="${STEP_CHUNK}" \
    WARMUP_STEPS="${WARMUP_STEPS:-10}" \
    LEARNING_RATE="${LEARNING_RATE}" \
    WEIGHT_DECAY="${WEIGHT_DECAY}" \
    OUTPUT_WEIGHT_DECAY="${output_wd}" \
    MUON_IMPL="${MUON_IMPL:-legacy}" \
    MIN_MASK_RATIO=1.0 \
    MAX_MASK_RATIO=1.0 \
    MASK_MIXTURE_LOWK_PROB=0.0 \
    MASK_MIXTURE_ALL_PROB=1.0 \
    LOWK_CLEAN_TOKENS=0 \
    CLEAN_STATE_MODE=onehot \
    TARGET_LOSS=hard_ce \
    DIRICHLET_CONCENTRATION_MIN=1.0 \
    DIRICHLET_CONCENTRATION_MAX=1024 \
    SIMPLEX_BRIDGE_SAMPLER=dirichlet \
    CATEGORICAL_WRONG_PROB_FLOOR=0.0 \
    ROLLOUT_TRAIN_PROB="${rollout_prob}" \
    ROLLOUT_TRAIN_STEPS="${rollout_steps}" \
    ROLLOUT_TRAIN_INFER_STEPS="${rollout_infer_steps}" \
    ROLLOUT_TRAIN_TEMP=1.45 \
    ROLLOUT_TRAIN_MAX_GAMMA=1.0 \
    ROLLOUT_TRAIN_CORRUPT_ONLY=1 \
    ROLLOUT_TRAIN_SAMPLEWISE=1 \
    ROLLOUT_TRAIN_SELECTED_ONLY=1 \
    ROLLOUT_TRAIN_COMPUTE_ALWAYS=0 \
    ROLLOUT_TRAIN_SYNC_T="${sync_t}" \
    RUN_NAME="${run_name}" \
    RESUME_PATH="${resume_path}" \
    bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh
    echo "[ctx1024-sweep] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"
    # Column 7 of the row just appended by eval_latest is exact_count.
    if tail -n 1 "${CURVE_CSV}" | awk -F, -v min_exact="${early_hit_exact}" '{exit (($7 >= min_exact) ? 0 : 1)}'; then
      echo "[ctx1024-sweep] early-hit config=${config}" | tee -a "${DRIVER_LOG}"
      break
    fi
  done
}
|
| 206 |
+
|
| 207 |
+
# Optionally hold the sweep until no process whose command line matches
# WAIT_FOR_RUN is running any more (polled once a minute).
if [[ -n "${WAIT_FOR_RUN}" ]]; then
  echo "[ctx1024-sweep] waiting for run=${WAIT_FOR_RUN}" | tee -a "${DRIVER_LOG}"
  until ! pgrep -f "${WAIT_FOR_RUN}" >/dev/null; do
    sleep 60
  done
fi

echo "[ctx1024-sweep] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"

# Built-in sweep table. Each entry is
#   config|rollout_prob|rollout_steps|rollout_infer_steps|output_wd|sync_t
CONFIGS=(
  "p75_s4_i32_outwdm1|0.75|4|32|-1|0"
  "p100_s4_i32_outwdm1|1.0|4|32|-1|0"
  "p50_s8_i64_outwdm1|0.50|8|64|-1|0"
  "p50_s4_i32_sync1_outwdm1|0.50|4|32|-1|1"
  "p50_s8_i64_sync1_outwdm1|0.50|8|64|-1|1"
  "p50_s4_i32_outwd0p3|0.50|4|32|0.3|0"
  "p50_s4_i32_outwd1p0|0.50|4|32|1.0|0"
)

# A newline-separated SWEEP_CONFIGS replaces the built-in table entirely.
# `read -d ''` hits end of input and returns non-zero, hence the `|| true`.
if [[ -n "${SWEEP_CONFIGS:-}" ]]; then
  IFS=$'\n' read -r -d '' -a CONFIGS < <(printf '%s\0' "${SWEEP_CONFIGS}") || true
fi

# Run the configurations one after another, in table order.
for entry in "${CONFIGS[@]}"; do
  IFS='|' read -r config rollout_prob rollout_steps rollout_infer_steps output_wd sync_t <<<"${entry}"
  run_config \
    "${config}" \
    "${rollout_prob}" \
    "${rollout_steps}" \
    "${rollout_infer_steps}" \
    "${output_wd}" \
    "${sync_t}"
done

echo "[ctx1024-sweep] done" | tee -a "${DRIVER_LOG}"
|
LTA_openwebtext_dualt/scripts/run_train8_len256_headwd_sweep_bs512_until_exact_4gpu.sh
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Setup for the train8 len256 head-weight-decay sweep: strict shell mode,
# working directory, Python environment and every tunable with its default.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false PYTHONUNBUFFERED=1

# Each knob keeps a non-empty value supplied by the caller's environment and
# otherwise falls back to the default written here.
: "${BASE_CACHE:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
: "${MAX_LEN:=256}"
: "${N_SAMPLES:=64}"
: "${STOP_EXACT_COUNT:=1}"
: "${STOP_EXACT_REF_COUNT:=0}"
: "${INFER_STEPS:=128}"
: "${STEP_CHUNK:=500}"
: "${MAX_TOTAL_STEPS:=100000}"
: "${MASTER_PORT_BASE:=38200}"
: "${PER_GPU_BATCH_SIZE:=128}"
: "${GLOBAL_BATCH_SIZE:=512}"
: "${OWT_EXACT_REPEAT_PER_CHUNK:=64}"
: "${LEARNING_RATE:=0.002}"
: "${WEIGHT_DECAY:=0.1}"
: "${HEAD_WD_VALUES:=0.1 0.3 1.0 3.0}"
: "${OUT_ROOT:=docs/lta_samples/metrics_20260517/train8_len256_headwd_sweep_bs512_until_exact_ode128}"
: "${DRIVER_LOG:=logs/train8_len256_headwd_sweep_bs512_until_exact_4gpu/driver.log}"
# The two result files default to locations under OUT_ROOT.
: "${SUMMARY_JSONL:=${OUT_ROOT}/summary_results.jsonl}"
: "${CURVE_CSV:=${OUT_ROOT}/hit_ratio_curve.csv}"

# Make sure the log directory and the output root exist before anything writes.
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"
|
| 31 |
+
|
| 32 |
+
# Print $1 as a path-friendly tag: every '-' becomes 'm' and every '.' becomes
# 'p' (for example -1.5 -> m1p5). No trailing newline is written.
tag_float() {
  local tagged="${1//-/m}"
  printf "%s" "${tagged//./p}"
}
|
| 35 |
+
|
| 36 |
+
# Print the vocabulary size stored in
# ${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit/meta.json.
# The "compact_vocab_size" key wins; "vocab_size" is the fallback.
vocab_size_for_len() {
  python - "$BASE_CACHE" "$MAX_LEN" <<'PY'
import json
import sys
from pathlib import Path

cache_root = Path(sys.argv[1])
length = int(sys.argv[2])
meta_path = cache_root / f"gpt2_len{length}_train8_compact_overfit" / "meta.json"
with meta_path.open() as handle:
    meta = json.load(handle)
if "compact_vocab_size" in meta:
    size = meta["compact_vocab_size"]
else:
    size = meta.get("vocab_size")
print(int(size))
PY
}
|
| 47 |
+
|
| 48 |
+
# Print the highest N among runs/<run_name>/step_N.pt checkpoints, or 0 when
# no such file exists. The path is resolved relative to the current directory.
latest_step() {
  local run_name="$1"
  python - "$run_name" <<'PY'
import re
import sys
from pathlib import Path

run_dir = Path("runs") / sys.argv[1]
step_pattern = re.compile(r"step_(\d+)\.pt$")
hits = (step_pattern.search(ckpt.name) for ckpt in run_dir.glob("step_*.pt"))
print(max((int(hit.group(1)) for hit in hits if hit), default=0))
PY
}
|
| 63 |
+
|
| 64 |
+
# Rebuild CURVE_CSV from scratch out of SUMMARY_JSONL.
#
# Only rows whose "run" field matches the head-wd sweep naming scheme are
# kept; the first row seen for a (context_len, head_wd, mode, step) key wins.
# Rows are sorted by context length, numeric head weight decay, mode and step.
# Prints "curve_rows=<count>" on stdout. A missing summary file yields a CSV
# that contains only the header.
refresh_curve() {
  python - "$SUMMARY_JSONL" "$CURVE_CSV" <<'PY'
import csv
import json
import re
import sys
from pathlib import Path

summary = Path(sys.argv[1])
out = Path(sys.argv[2])
run_pattern = re.compile(
    r"train8_n(\d+)_compactv\d+_small_headwd([^_]+)_bs512_hard_ce_(onehot|allcorrupt)"
)
rows = []
seen = set()
if summary.exists():
    # The summary is appended as UTF-8 with ensure_ascii=False, so decode it
    # explicitly instead of relying on the locale's default encoding.
    for line in summary.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        row = json.loads(line)
        m = run_pattern.search(row.get("run", ""))
        if not m:
            continue
        max_len = int(m.group(1))
        # Undo the filename tagging: 'p' stands for '.', 'm' for '-'.
        head_wd = m.group(2).replace("p", ".").replace("m", "-")
        mode = m.group(3)
        step = int(row["ckpt_step"])
        key = (max_len, head_wd, mode, step)
        if key in seen:
            continue
        seen.add(key)
        n_gen = float(row.get("n_gen", 64))
        rows.append({
            "context_len": max_len,
            "head_weight_decay": head_wd,
            "mode": mode,
            "step": step,
            "train_views_seen": int(row.get("train_views_seen", step * 512)),
            "train_tokens_seen": int(row.get("train_tokens_seen", step * 512 * max_len)),
            # A row reporting zero generations must not abort the whole
            # rebuild with ZeroDivisionError; report a 0.0 hit ratio instead.
            "hit_ratio": float(row.get("exact_count", 0)) / n_gen if n_gen > 0 else 0.0,
            "exact_count": int(row.get("exact_count", 0)),
            "ref_coverage": float(row.get("exact_ref_count", 0)) / 8.0,
            "exact_ref_count": int(row.get("exact_ref_count", 0)),
            "token_acc_mean": float(row.get("token_acc_mean", 0.0)),
        })
rows.sort(key=lambda x: (x["context_len"], float(x["head_weight_decay"]), x["mode"], x["step"]))
out.parent.mkdir(parents=True, exist_ok=True)
with out.open("w", newline="", encoding="utf-8") as f:
    fieldnames = [
        "context_len",
        "head_weight_decay",
        "mode",
        "step",
        "train_views_seen",
        "train_tokens_seen",
        "hit_ratio",
        "exact_count",
        "ref_coverage",
        "exact_ref_count",
        "token_acc_mean",
    ]
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
print(f"curve_rows={len(rows)}")
PY
}
|
| 131 |
+
|
| 132 |
+
# Evaluate the latest checkpoint of a run and record the result.
#
# Arguments:
#   $1 run_name     directory name under runs/
#   $2 target_step  step used to name the output directory
#   $3 head_wd      head weight decay, stored in the summary row
#   $4 cache        dataset directory passed as --data_dir
#
# Runs scripts/eval_train8_decode_acc.py, then appends the last line of
# <out_dir>/decode_token_acc.jsonl (plus head_weight_decay, train_views_seen
# and train_tokens_seen) to SUMMARY_JSONL and prints a RESULT line.
#
# Returns:
#   0  exact_count >= STOP_EXACT_COUNT and exact_ref_count >= STOP_EXACT_REF_COUNT
#   1  thresholds not met
#   2  the evaluation script itself failed; nothing is appended to the summary
eval_latest() {
  local run_name="$1"
  local target_step="$2"
  local head_wd="$3"
  local cache="$4"
  local out_dir="${OUT_ROOT}/len${MAX_LEN}/headwd_$(tag_float "${head_wd}")/${run_name}/step_${target_step}"
  mkdir -p "${out_dir}"
  # This function is invoked inside an `if ... | tee` condition, where bash
  # ignores `set -e`. Without an explicit check a crashed evaluation would
  # fall through to the summarizer, which could then record a stale
  # decode_token_acc.jsonl left in out_dir by an earlier attempt.
  if ! CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py \
    --runs_glob "runs/${run_name}" \
    --data_dir "${cache}" \
    --tokenizer_path "${TOKENIZER_PATH}" \
    --out_dir "${out_dir}" \
    --max_len "${MAX_LEN}" \
    --n_samples "${N_SAMPLES}" \
    --batch_size "${N_SAMPLES}" \
    --latest_only \
    --endpoint_softenings none \
    --steps "${INFER_STEPS}" \
    --decode_rule flowmap \
    --time_schedule logit_normal \
    --time_logit_mean -1.5 \
    --time_logit_std 0.8 \
    --model_t_mode post \
    --c_min 1 \
    --c_max 512 \
    --late_temp 1.0 \
    --final_from state \
    --final_decode argmax; then
    echo "[headwd-sweep] eval failed run=${run_name} step=${target_step}" >&2
    return 2
  fi
  python - "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$SUMMARY_JSONL" "$STOP_EXACT_COUNT" "$STOP_EXACT_REF_COUNT" "$head_wd" <<'PY'
import json
import sys
from pathlib import Path

out = Path(sys.argv[1])
n = int(sys.argv[2])
global_batch = int(sys.argv[3])
max_len = int(sys.argv[4])
summary = Path(sys.argv[5])
stop_exact_count = int(sys.argv[6])
stop_exact_ref_count = int(sys.argv[7])
head_wd = float(sys.argv[8])
# Only the most recent evaluation record is used.
row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
row["head_weight_decay"] = head_wd
row["train_views_seen"] = int(row["ckpt_step"]) * global_batch
row["train_tokens_seen"] = row["train_views_seen"] * max_len
line = (
    "RESULT "
    f"len={max_len} "
    f"head_wd={head_wd:g} "
    f"run={row['run']} "
    f"ckpt_step={row['ckpt_step']} "
    f"train_views_seen={row['train_views_seen']} "
    f"train_tokens_seen={row['train_tokens_seen']} "
    f"token_acc_mean={row['token_acc_mean']:.4f} "
    f"exact_count={row['exact_count']}/{n} "
    f"exact_ref_count={row['exact_ref_count']} "
    f"exact_refs={row['exact_ref_hits']}"
)
print(line, flush=True)
summary.parent.mkdir(parents=True, exist_ok=True)
with summary.open("a", encoding="utf-8") as f:
    f.write(json.dumps(row, ensure_ascii=False) + "\n")
# The exit status tells the caller whether the stop criterion is met.
raise SystemExit(
    0
    if int(row["exact_count"]) >= stop_exact_count
    and int(row["exact_ref_count"]) >= stop_exact_ref_count
    else 1
)
PY
}
|
| 202 |
+
|
| 203 |
+
# Dataset directory for the configured context length, and the vocabulary
# size read from its meta.json.
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(vocab_size_for_len)"

# Spec layout: mode|MASK_MIXTURE_LOWK_PROB|MASK_MIXTURE_ALL_PROB|LOWK_CLEAN_TOKENS
mode_specs=(
  "onehot|1|0|1,2,4"
  "allcorrupt|0|1|0"
)

for head_wd in ${HEAD_WD_VALUES}; do
  head_tag="$(tag_float "${head_wd}")"
  echo "[headwd-sweep] head_wd=${head_wd} $(date)" | tee -a "${DRIVER_LOG}"
  for mode_spec in "${mode_specs[@]}"; do
    IFS='|' read -r mode lowk_prob all_prob lowk_tokens <<<"${mode_spec}"
    run_name="train8_n${MAX_LEN}_compactv${vocab_size}_small_headwd${head_tag}_bs512_hard_ce_${mode}"
    done_flag="${OUT_ROOT}/len${MAX_LEN}/headwd_${head_tag}/${run_name}/DONE"

    # Train in STEP_CHUNK increments and evaluate after each one. Stop when
    # eval_latest reports success or the step budget is exhausted.
    while :; do
      if [[ -f "${done_flag}" ]]; then
        echo "[headwd-sweep] skip done head_wd=${head_wd} mode=${mode} run=${run_name}" | tee -a "${DRIVER_LOG}"
        break
      fi

      step_now="$(latest_step "${run_name}")"
      if (( step_now >= MAX_TOTAL_STEPS )); then
        echo "[headwd-sweep] hit max_total_steps head_wd=${head_wd} mode=${mode} step=${step_now}" | tee -a "${DRIVER_LOG}"
        break
      fi
      # Next stopping point, clamped to the overall budget.
      target_step=$((step_now + STEP_CHUNK))
      (( target_step <= MAX_TOTAL_STEPS )) || target_step="${MAX_TOTAL_STEPS}"

      # Resume from latest.pt when the run already has one.
      resume_path="runs/${run_name}/latest.pt"
      [[ -f "${resume_path}" ]] || resume_path=""

      echo "[headwd-sweep] train head_wd=${head_wd} mode=${mode} run=${run_name} from=${step_now} to=${target_step}" | tee -a "${DRIVER_LOG}"
      # All settings reach the launcher through its environment.
      CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
      NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
      MASTER_PORT=$((MASTER_PORT_BASE + target_step / STEP_CHUNK)) \
      OWT_CHUNK_CACHE_DIR="${cache}" \
      OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK}" \
      MAX_LEN="${MAX_LEN}" \
      VOCAB_SIZE_OVERRIDE="${vocab_size}" \
      D_MODEL="${D_MODEL:-192}" \
      COND_DIM="${COND_DIM:-64}" \
      N_LAYERS="${N_LAYERS:-3}" \
      N_HEADS="${N_HEADS:-3}" \
      DIM_FF="${DIM_FF:-768}" \
      LOWK_CLEAN_TOKENS="${lowk_tokens}" \
      MASK_MIXTURE_LOWK_PROB="${lowk_prob}" \
      MASK_MIXTURE_ALL_PROB="${all_prob}" \
      TOTAL_STEPS="${target_step}" \
      PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
      GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
      NUM_WORKERS="${NUM_WORKERS:-0}" \
      LOG_EVERY="${LOG_EVERY:-50}" \
      SAVE_EVERY="${STEP_CHUNK}" \
      LATEST_EVERY="${STEP_CHUNK}" \
      WARMUP_STEPS="${WARMUP_STEPS:-10}" \
      LEARNING_RATE="${LEARNING_RATE}" \
      WEIGHT_DECAY="${WEIGHT_DECAY}" \
      OUTPUT_WEIGHT_DECAY="${head_wd}" \
      MUON_IMPL="${MUON_IMPL:-optax}" \
      TARGET_LOSS="hard_ce" \
      CLEAN_STATE_MODE="onehot" \
      RUN_NAME="${run_name}" \
      RESUME_PATH="${resume_path}" \
      bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh

      echo "[headwd-sweep] eval head_wd=${head_wd} mode=${mode} run=${run_name} step=${target_step}" | tee -a "${DRIVER_LOG}"
      # With pipefail set, the pipeline status reflects eval_latest, not tee.
      if eval_latest "${run_name}" "${target_step}" "${head_wd}" "${cache}" | tee -a "${DRIVER_LOG}"; then
        echo "[headwd-sweep] DONE head_wd=${head_wd} mode=${mode} run=${run_name} step=${target_step}" | tee -a "${DRIVER_LOG}"
        touch "${done_flag}"
        refresh_curve | tee -a "${DRIVER_LOG}"
        break
      fi
      echo "[headwd-sweep] continue head_wd=${head_wd} mode=${mode} run=${run_name} step=${target_step}" | tee -a "${DRIVER_LOG}"
      refresh_curve | tee -a "${DRIVER_LOG}"
    done
  done
done

refresh_curve | tee -a "${DRIVER_LOG}"
|
LTA_openwebtext_dualt/scripts/run_train8_noise_geometry_pilots_4gpu.sh
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Setup for the train8 noise-geometry pilots: strict shell mode, working
# directory, Python environment, tunables, vocabulary lookup and CSV header.
set -euo pipefail

cd /e2e-data/evad-tech-vla/wanghan58/workspace/LTA_openwebtext_dualt

export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false PYTHONUNBUFFERED=1

# Each knob keeps a non-empty value supplied by the caller's environment and
# otherwise falls back to the default written here.
: "${BASE_CACHE:=/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks}"
: "${TOKENIZER_PATH:=/e2e-data/evad-tech-vla/wanghan58/models/flowtext_scorers/gpt2-standard/tokenizer.json}"
: "${MAX_LEN:=256}"
: "${N_SAMPLES:=64}"
: "${INFER_STEPS:=128}"
: "${STEP_CHUNK:=1000}"
: "${MAX_TOTAL_STEPS:=30000}"
: "${STOP_EXACT_COUNT:=64}"
: "${STOP_EXACT_REF_COUNT:=8}"
: "${MASTER_PORT_BASE:=32950}"
: "${PER_GPU_BATCH_SIZE:=128}"
: "${GLOBAL_BATCH_SIZE:=512}"
# GROUP_STAMP defaults to the launch time and is part of the default output
# root and driver log path.
: "${GROUP_STAMP:=$(date +%Y%m%d_%H%M%S)}"
: "${OUT_ROOT:=docs/lta_samples/metrics_20260517/noise_geometry_pilots_len${MAX_LEN}_bs512_ode128_${GROUP_STAMP}}"
: "${DRIVER_LOG:=logs/noise_geometry_pilots_4gpu/${GROUP_STAMP}.log}"
: "${SUMMARY_JSONL:=${OUT_ROOT}/summary_results.jsonl}"
: "${CURVE_CSV:=${OUT_ROOT}/hit_ratio_curve.csv}"
mkdir -p "$(dirname "${DRIVER_LOG}")" "${OUT_ROOT}"

# Dataset directory for the configured context length. Its meta.json supplies
# the vocabulary size, with "compact_vocab_size" preferred over "vocab_size".
cache="${BASE_CACHE}/gpt2_len${MAX_LEN}_train8_compact_overfit"
vocab_size="$(
  python - "$cache" <<'PY'
import json
import sys
from pathlib import Path

with (Path(sys.argv[1]) / "meta.json").open() as handle:
    meta = json.load(handle)
if "compact_vocab_size" in meta:
    size = meta["compact_vocab_size"]
else:
    size = meta.get("vocab_size")
print(int(size))
PY
)"

# Write the CSV header once; an existing curve file is left untouched.
[[ -f "${CURVE_CSV}" ]] || echo "config,ckpt_step,train_views_seen,train_tokens_seen,token_acc_mean,exact_count,exact_ref_count,exact_ref_hits" > "${CURVE_CSV}"
|
| 43 |
+
|
| 44 |
+
# Print the highest N among runs/<run_name>/step_N.pt checkpoints, or 0 when
# no such file exists. The path is resolved relative to the current directory.
latest_step() {
  local run_name="$1"
  python - "$run_name" <<'PY'
import re
import sys
from pathlib import Path

run_dir = Path("runs") / sys.argv[1]
step_pattern = re.compile(r"step_(\d+)\.pt$")
hits = (step_pattern.search(ckpt.name) for ckpt in run_dir.glob("step_*.pt"))
print(max((int(hit.group(1)) for hit in hits if hit), default=0))
PY
}
|
| 59 |
+
|
| 60 |
+
# Print a TCP port number chosen by the OS: bind a socket to 127.0.0.1 with
# port 0, report the assigned port, and close the socket again.
free_port() {
  python - <<'PY'
import socket

probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    probe.bind(("127.0.0.1", 0))
    _host, port = probe.getsockname()
    print(port)
finally:
    probe.close()
PY
}
|
| 68 |
+
|
| 69 |
+
# Evaluate the latest checkpoint of a run and record the result.
#
# Arguments:
#   $1 config       config label, stored in the summary row and the curve CSV
#   $2 run_name     directory name under runs/
#   $3 target_step  step used to name the output directory
#
# Runs scripts/eval_train8_decode_acc.py, then appends the last line of
# <out_dir>/decode_token_acc.jsonl (plus config, train_views_seen and
# train_tokens_seen) to SUMMARY_JSONL, appends one line to CURVE_CSV and
# prints a RESULT line.
#
# Returns:
#   0  exact_count >= STOP_EXACT_COUNT and exact_ref_count >= STOP_EXACT_REF_COUNT
#   1  thresholds not met
#   2  the evaluation script itself failed; no summary or curve row is written
eval_latest() {
  local config="$1"
  local run_name="$2"
  local target_step="$3"
  local out_dir="${OUT_ROOT}/${config}/step_${target_step}"
  mkdir -p "${out_dir}"
  # This function is invoked inside an `if ... | tee` condition, where bash
  # ignores `set -e`. Without an explicit check a crashed evaluation would
  # fall through to the summarizer, which could then record a stale
  # decode_token_acc.jsonl left in out_dir by an earlier attempt.
  if ! CUDA_VISIBLE_DEVICES="${EVAL_CUDA_VISIBLE_DEVICES:-0}" python scripts/eval_train8_decode_acc.py \
    --runs_glob "runs/${run_name}" \
    --data_dir "${cache}" \
    --tokenizer_path "${TOKENIZER_PATH}" \
    --out_dir "${out_dir}" \
    --max_len "${MAX_LEN}" \
    --n_samples "${N_SAMPLES}" \
    --batch_size "${N_SAMPLES}" \
    --latest_only \
    --endpoint_softenings none \
    --steps "${INFER_STEPS}" \
    --decode_rule flowmap \
    --time_schedule logit_normal \
    --time_logit_mean -1.5 \
    --time_logit_std 0.8 \
    --model_t_mode post \
    --c_min 1 \
    --c_max 512 \
    --late_temp 1.0 \
    --final_from state \
    --final_decode argmax; then
    echo "[noise-pilot] eval failed config=${config} step=${target_step}" >&2
    return 2
  fi
  python - "$config" "$out_dir" "$N_SAMPLES" "$GLOBAL_BATCH_SIZE" "$MAX_LEN" "$SUMMARY_JSONL" "$CURVE_CSV" "$STOP_EXACT_COUNT" "$STOP_EXACT_REF_COUNT" <<'PY'
import json
import sys
from pathlib import Path

config = sys.argv[1]
out = Path(sys.argv[2])
n = int(sys.argv[3])
global_batch = int(sys.argv[4])
max_len = int(sys.argv[5])
summary = Path(sys.argv[6])
curve = Path(sys.argv[7])
stop_exact_count = int(sys.argv[8])
stop_exact_ref_count = int(sys.argv[9])
# Only the most recent evaluation record is used.
row = json.loads((out / "decode_token_acc.jsonl").read_text().splitlines()[-1])
row["config"] = config
row["train_views_seen"] = int(row["ckpt_step"]) * global_batch
row["train_tokens_seen"] = row["train_views_seen"] * max_len
line = (
    "RESULT "
    f"config={config} "
    f"ckpt_step={row['ckpt_step']} "
    f"views={row['train_views_seen']} "
    f"token_acc={row['token_acc_mean']:.4f} "
    f"exact={row['exact_count']}/{n} "
    f"exact_refs={row['exact_ref_count']} "
    f"hits={row['exact_ref_hits']}"
)
print(line, flush=True)
summary.parent.mkdir(parents=True, exist_ok=True)
with summary.open("a", encoding="utf-8") as f:
    f.write(json.dumps(row, ensure_ascii=False) + "\n")
# Columns follow the header written when the curve file is created; the hits
# value is wrapped in double quotes because it may contain commas.
with curve.open("a", encoding="utf-8") as f:
    f.write(
        f"{config},{row['ckpt_step']},{row['train_views_seen']},{row['train_tokens_seen']},"
        f"{row['token_acc_mean']},{row['exact_count']},{row['exact_ref_count']},"
        f"\"{row['exact_ref_hits']}\"\n"
    )
# The exit status tells the caller whether the stop criterion is met.
raise SystemExit(
    0
    if int(row["exact_count"]) >= stop_exact_count
    and int(row["exact_ref_count"]) >= stop_exact_ref_count
    else 1
)
PY
}
|
| 142 |
+
|
| 143 |
+
# Export the baseline value of every noise-related variable, so that settings
# exported for a previous config cannot leak into the next one.
reset_noise_defaults() {
  export \
    MIN_MASK_RATIO=1.0 \
    MAX_MASK_RATIO=1.0 \
    MASK_MIXTURE_LOWK_PROB=0.0 \
    MASK_MIXTURE_ALL_PROB=1.0 \
    LOWK_CLEAN_TOKENS=0 \
    CLEAN_STATE_MODE=onehot \
    TARGET_LOSS=hard_ce \
    LINEAR_SOFT_TARGET_POWER=1.0 \
    LINEAR_SOFT_TARGET_MIN_CONF=0.0 \
    LINEAR_SOFT_TARGET_MAX_CONF=1.0 \
    DIRICHLET_CONCENTRATION_MIN=1.0 \
    DIRICHLET_CONCENTRATION_MAX=1024 \
    SIMPLEX_BRIDGE_SAMPLER=dirichlet \
    ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA=0.0 \
    CATEGORICAL_WRONG_BASIN_TOKEN_IDS="" \
    CATEGORICAL_WRONG_BASIN_PROB=0.0 \
    CATEGORICAL_WRONG_UNIGRAM_PROB=0.0 \
    CATEGORICAL_WRONG_UNIFORM_PROB=0.0 \
    CATEGORICAL_WRONG_BASIN_SHARED_PROB=0.0 \
    CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB=0.0 \
    LOGISTIC_NORMAL_SIGMA_MIN=0.1 \
    LOGISTIC_NORMAL_SIGMA_MAX=1.0 \
    LOGISTIC_NORMAL_TAU_MIN=1.0 \
    LOGISTIC_NORMAL_TAU_MAX=1.0
}
|
| 169 |
+
|
| 170 |
+
# Export the noise settings for one named config: restore the defaults first,
# then apply that config's overrides. An unknown name prints an error on
# stderr and returns 2.
configure_noise() {
  local config="$1"
  reset_noise_defaults
  case "${config}" in
    allcorrupt_fullvocab_dirC1_1024)
      # Runs on the defaults unchanged.
      ;;
    allcorrupt_seqrand1_dirC1_1024)
      export ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA=1.0
      ;;
    allcorrupt_highC64_4096)
      export DIRICHLET_CONCENTRATION_MIN=64 DIRICHLET_CONCENTRATION_MAX=4096
      ;;
    allcorrupt_logistic_sig0p05_0p5)
      export SIMPLEX_BRIDGE_SAMPLER=logistic_normal_linear_mean
      export LOGISTIC_NORMAL_SIGMA_MIN=0.05 LOGISTIC_NORMAL_SIGMA_MAX=0.5
      ;;
    allcorrupt_seqrand1_highC64_4096)
      export ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA=1.0
      export DIRICHLET_CONCENTRATION_MIN=64 DIRICHLET_CONCENTRATION_MAX=4096
      ;;
    allcorrupt_unigram_shared0p5)
      export CATEGORICAL_WRONG_UNIGRAM_PROB=1.0 CATEGORICAL_WRONG_UNIGRAM_SHARED_PROB=0.5
      ;;
    *)
      echo "unknown config: ${config}" >&2
      return 2
      ;;
  esac
}
|
| 203 |
+
|
| 204 |
+
# Config names handled by configure_noise, visited in this order every round.
configs=(
  allcorrupt_fullvocab_dirC1_1024
  allcorrupt_seqrand1_dirC1_1024
  allcorrupt_highC64_4096
  allcorrupt_logistic_sig0p05_0p5
  allcorrupt_seqrand1_highC64_4096
  allcorrupt_unigram_shared0p5
)

echo "[noise-pilot] start stamp=${GROUP_STAMP} len=${MAX_LEN} vocab=${vocab_size} out=${OUT_ROOT}" | tee -a "${DRIVER_LOG}"

# Round-robin schedule: each round advances every unfinished config by one
# STEP_CHUNK and evaluates it. The loop ends after a round in which every
# config was already done or had reached MAX_TOTAL_STEPS.
round_idx=0
while :; do
  round_idx=$((round_idx + 1))
  active=0
  echo "[noise-pilot] round=${round_idx} $(date)" | tee -a "${DRIVER_LOG}"

  idx=0
  for config in "${configs[@]}"; do
    idx=$((idx + 1))
    run_name="train8_noisegeo_len${MAX_LEN}_${config}_${GROUP_STAMP}"
    done_flag="${OUT_ROOT}/${config}/DONE"
    if [[ -f "${done_flag}" ]]; then
      echo "[noise-pilot] skip done config=${config}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    step_now="$(latest_step "${run_name}")"
    if (( step_now >= MAX_TOTAL_STEPS )); then
      echo "[noise-pilot] capped config=${config} step=${step_now}" | tee -a "${DRIVER_LOG}"
      continue
    fi
    active=1
    # Next stopping point, clamped to the overall budget.
    target_step=$((step_now + STEP_CHUNK))
    (( target_step <= MAX_TOTAL_STEPS )) || target_step="${MAX_TOTAL_STEPS}"
    # Resume from latest.pt when the run already has one.
    resume_path="runs/${run_name}/latest.pt"
    [[ -f "${resume_path}" ]] || resume_path=""

    # Noise settings are exported into this shell and inherited by the launcher.
    configure_noise "${config}"
    echo "[noise-pilot] train config=${config} from=${step_now} to=${target_step} C=${DIRICHLET_CONCENTRATION_MIN}->${DIRICHLET_CONCENTRATION_MAX} seq_alpha=${ENDPOINT_SEQUENCE_RANDOM_PROB_ALPHA} sampler=${SIMPLEX_BRIDGE_SAMPLER}" | tee -a "${DRIVER_LOG}"

    CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" \
    NPROC_PER_NODE="${NPROC_PER_NODE:-4}" \
    MASTER_PORT="$(free_port)" \
    OWT_CHUNK_CACHE_DIR="${cache}" \
    OWT_EXACT_REPEAT_PER_CHUNK="${OWT_EXACT_REPEAT_PER_CHUNK:-64}" \
    MAX_LEN="${MAX_LEN}" \
    VOCAB_SIZE_OVERRIDE="${vocab_size}" \
    D_MODEL="${D_MODEL:-192}" \
    COND_DIM="${COND_DIM:-64}" \
    N_LAYERS="${N_LAYERS:-3}" \
    N_HEADS="${N_HEADS:-3}" \
    DIM_FF="${DIM_FF:-768}" \
    TOTAL_STEPS="${target_step}" \
    PER_GPU_BATCH_SIZE="${PER_GPU_BATCH_SIZE}" \
    GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE}" \
    NUM_WORKERS="${NUM_WORKERS:-0}" \
    LOG_EVERY="${LOG_EVERY:-100}" \
    SAVE_EVERY="${STEP_CHUNK}" \
    LATEST_EVERY="${STEP_CHUNK}" \
    WARMUP_STEPS="${WARMUP_STEPS:-10}" \
    LEARNING_RATE="${LEARNING_RATE:-0.002}" \
    WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" \
    MUON_IMPL="${MUON_IMPL:-legacy}" \
    OUTPUT_WEIGHT_DECAY="${OUTPUT_WEIGHT_DECAY:--1}" \
    RUN_NAME="${run_name}" \
    RESUME_PATH="${resume_path}" \
    bash scripts/launch_lta_owt_gpt2_softendpoint_mn_pilot_4gpu.sh

    echo "[noise-pilot] eval config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    # With pipefail set, the pipeline status reflects eval_latest, not tee.
    if eval_latest "${config}" "${run_name}" "${target_step}" | tee -a "${DRIVER_LOG}"; then
      echo "[noise-pilot] DONE config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
      touch "${done_flag}"
    else
      echo "[noise-pilot] continue config=${config} step=${target_step}" | tee -a "${DRIVER_LOG}"
    fi
  done

  if (( active == 0 )); then
    echo "[noise-pilot] all configs done/capped $(date)" | tee -a "${DRIVER_LOG}"
    break
  fi
done
|
LTA_openwebtext_dualt/scripts/sweep_categorical_c1024_diffusion_finalsample_20260506.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import importlib.util
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Load the sibling rolling-noise sweep script by file path and register it in
# sys.modules under its file stem before executing it.
BASE_PATH = Path(__file__).with_name("sweep_categorical_c1024_rolling_noise_20260506.py")
spec = importlib.util.spec_from_file_location(BASE_PATH.stem, BASE_PATH)
if spec is None or spec.loader is None:
    raise RuntimeError(f"Could not load {BASE_PATH}")
sweep = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = sweep
spec.loader.exec_module(sweep)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def finalsample_configs():
    """Return the final-sample decode configurations for this sweep.

    Every entry is a ``sweep.DecodeCfg`` with ``final_sample=True``; entries
    differ in name, ``endpoint_temp``, ``final_temp`` and a few per-entry
    keyword overrides.
    """
    make_cfg = sweep.DecodeCfg
    # (name, endpoint_temp, final_temp, extra keyword overrides)
    table = [
        ("fs_t1p3_eta1_blend_ft0p35", 1.3, 0.35, {}),
        ("fs_t1p3_eta1_blend_ft0p50", 1.3, 0.50, {}),
        ("fs_t1p3_eta1_blend_ft0p70", 1.3, 0.70, {}),
        ("fs_t1p3_eta1_blend_ft1p00", 1.3, 1.00, {}),
        ("fs_t1p3_eta1_blend_ft1p30", 1.3, 1.30, {}),
        ("fs_t1p3_eta1_blend_ft1p60", 1.3, 1.60, {}),
        ("fs_t1p2_eta1_blend_ft0p50", 1.2, 0.50, {}),
        ("fs_t1p2_eta1_blend_ft0p70", 1.2, 0.70, {}),
        ("fs_t1p2_eta1_blend_ft1p00", 1.2, 1.00, {}),
        ("fs_t1p1_eta1_blend_ft0p70", 1.1, 0.70, {}),
        ("fs_t1p3_eta1_state_ft0p35", 1.3, 0.35, {"final_from": "state"}),
        ("fs_t1p3_eta1_state_ft0p50", 1.3, 0.50, {"final_from": "state"}),
        (
            "fs_sp1p15_sem1p7_t1p3_eta1_blend_ft0p50",
            1.3,
            0.50,
            {"support_power": 1.15, "semantic_power": 1.7},
        ),
    ]
    return [
        make_cfg(
            name,
            endpoint_temp=endpoint_temp,
            final_sample=True,
            final_temp=final_temp,
            **extra,
        )
        for name, endpoint_temp, final_temp, extra in table
    ]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Replace the base module's default_configs with the final-sample list.
# NOTE(review): presumably sweep.main() looks up default_configs at call time;
# confirm against the base script.
sweep.default_configs = finalsample_configs


if __name__ == "__main__":
    # Run the base script's entry point only when executed directly.
    sweep.main()
|