fuzirui committed on
Commit
92dbe65
·
verified ·
1 Parent(s): 3cb162f

push_to_jobs: --replace-extracted; cache/pip/tmp on bucket

Browse files
.gitignore CHANGED
@@ -1,17 +1,17 @@
1
- .venv/
2
- venv/
3
- data/
4
- __pycache__/
5
- *.pyc
6
- *.pyo
7
- .pytest_cache/
8
- .mypy_cache/
9
- .ruff_cache/
10
- *.egg-info/
11
- dist/
12
- build/
13
- .git/
14
- .cursor/
15
- agent-tools/
16
- *.pt
17
- .DS_Store
 
1
+ .venv/
2
+ venv/
3
+ data/
4
+ __pycache__/
5
+ *.pyc
6
+ *.pyo
7
+ .pytest_cache/
8
+ .mypy_cache/
9
+ .ruff_cache/
10
+ *.egg-info/
11
+ dist/
12
+ build/
13
+ .git/
14
+ .cursor/
15
+ agent-tools/
16
+ *.pt
17
+ .DS_Store
configs/jobs_overrides.yaml CHANGED
@@ -1,20 +1,20 @@
1
- # HF Jobs 训练覆盖项(与 configs/default.yaml 深度合并)。
2
- # 用法: python -m wjad.train.runner_local --config configs/default.yaml --config_overrides configs/jobs_overrides.yaml
3
-
4
- train:
5
- batch_size: 12
6
- grad_accum_steps: 1 # 显存不够可改大累积步并减小 batch
7
-
8
- data:
9
- use_synthetic: true
10
- use_real: true
11
-
12
- deploy:
13
- hf_code_repo: "fuzirui/WJAD"
14
- hf_weights_repo: "fuzirui/WJAD"
15
- hf_bucket_id: "fuzirui/WJAD"
16
- mirror_src_uri: "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
17
- mirror_bucket_subpath: "mirror/cosmos_hub"
18
- extracted_bucket_subpath: "extracted/cosmos_hub"
19
- cache_bucket_subpath: "cache"
20
- runs_bucket_subpath: "runs"
 
1
+ # HF Jobs 训练覆盖项(与 configs/default.yaml 深度合并)。
2
+ # 用法: python -m wjad.train.runner_local --config configs/default.yaml --config_overrides configs/jobs_overrides.yaml
3
+
4
+ train:
5
+ batch_size: 12
6
+ grad_accum_steps: 1 # 显存不够可改大累积步并减小 batch
7
+
8
+ data:
9
+ use_synthetic: true
10
+ use_real: true
11
+
12
+ deploy:
13
+ hf_code_repo: "fuzirui/WJAD"
14
+ hf_weights_repo: "fuzirui/WJAD"
15
+ hf_bucket_id: "fuzirui/WJAD"
16
+ mirror_src_uri: "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
17
+ mirror_bucket_subpath: "mirror/cosmos_hub"
18
+ extracted_bucket_subpath: "extracted/cosmos_hub"
19
+ cache_bucket_subpath: "cache"
20
+ runs_bucket_subpath: "runs"
scripts/clear_wjad_extract_state.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """删除 ``jobs_extract_archives.py`` 在 ``--out-root`` 下生成的 ``_wjad_extract_state`` 目录。
2
+
3
+ 只删「已解压」进度标记,**不会**删除 ``extracted/`` 里真正的数据集文件。
4
+ 删完后可重新跑解压脚本,会按归档从头再解压一遍(若目录已存在可能覆盖/混合,按需先清 extracted)。
5
+
6
+ 用法示例::
7
+
8
+ # 本机 / Bucket 挂载路径与解压时 --out-root 一致
9
+ python scripts/clear_wjad_extract_state.py --out-root /mnt/wjad/extracted/cosmos_hub
10
+
11
+ # 只看会删什么,不真删
12
+ python scripts/clear_wjad_extract_state.py --out-root F:/bucket/extracted/cosmos_hub --dry-run
13
+
14
+ 环境变量(可选)::
15
+
16
+ 设置 WJAD_EXTRACTED_ROOT 后可省略 --out-root
17
+
18
+ Linux Job::
19
+
20
+ export WJAD_EXTRACTED_ROOT=/mnt/wjad/extracted/cosmos_hub
21
+ python scripts/clear_wjad_extract_state.py
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import os
28
+ import shutil
29
+ import sys
30
+ from pathlib import Path
31
+
32
STATE_DIRNAME = "_wjad_extract_state"


def main() -> None:
    """Delete the ``_wjad_extract_state`` directory under the extract out-root.

    The out-root comes from ``--out-root`` or, when that is omitted, from the
    ``WJAD_EXTRACTED_ROOT`` environment variable. Only the state directory
    itself is removed; nothing else under the out-root is touched.

    Exit codes: ``2`` when no out-root can be determined, ``1`` when the state
    path exists but is not a directory. A missing state directory is reported
    and treated as success.
    """
    parser = argparse.ArgumentParser(description="Remove _wjad_extract_state under extracted out-root.")
    parser.add_argument(
        "--out-root",
        type=Path,
        default=None,
        help="与 jobs_extract_archives --out-root 相同;不设则用环境变量 WJAD_EXTRACTED_ROOT",
    )
    parser.add_argument("--dry-run", action="store_true", help="只打印路径,不删除")
    opts = parser.parse_args()

    # Resolve the out-root: the CLI flag wins, the environment is the fallback.
    root = opts.out_root
    if root is None:
        from_env = os.environ.get("WJAD_EXTRACTED_ROOT")
        if not from_env:
            print("需要 --out-root 或环境变量 WJAD_EXTRACTED_ROOT", file=sys.stderr)
            sys.exit(2)
        root = Path(from_env)

    target = root.resolve() / STATE_DIRNAME

    # Guard clauses: nothing to do, or something unexpected sits at the path.
    if not target.exists():
        print(f"[clear] 不存在,跳过: {target}")
        return
    if not target.is_dir():
        print(f"[clear] 不是目录,拒绝: {target}", file=sys.stderr)
        sys.exit(1)

    prefix = "(dry-run) " if opts.dry_run else ""
    print(f"[clear] {prefix}目标: {target}")

    if opts.dry_run:
        # Count every entry (files and directories) below the state directory.
        entry_count = len(list(target.rglob("*")))
        print(f"[clear] dry-run: 其下约 {entry_count} 个条目(含文件与目录)")
        return

    shutil.rmtree(target)
    print(f"[clear] 已删除: {target}")
70
+
71
+
72
+ if __name__ == "__main__":
73
+ main()
scripts/jobs_entry_train.sh CHANGED
@@ -1,51 +1,57 @@
1
- #!/usr/bin/env bash
2
- # HF Job(GPU)入口:在仓库根目录执行(由 push_to_jobs 先 clone 再调用本脚本)。
3
- # 数据与缓存路径默认指向已挂载的 Bucket(WJAD_BUCKET_MOUNT),避免占用 Job 本地盘。
4
- set -euo pipefail
5
-
6
- REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
7
- cd "$REPO_ROOT"
8
-
9
- BUCKET="${WJAD_BUCKET_MOUNT:-/mnt/wjad}"
10
- HUB_REPO="${WJAD_HUB_REPO:-fuzirui/WJAD}"
11
- # 缓存统一在 bucket 的 cache/ 下,与镜像(mirror)和解压数据(extracted)分离,可整目录删除清缓存
12
- CACHE_ROOT="${WJAD_CACHE_ROOT:-${BUCKET}/cache}"
13
-
14
- export HF_HOME="${CACHE_ROOT}/hf"
15
- export TRANSFORMERS_CACHE="${CACHE_ROOT}/transformers"
16
- export TORCH_HOME="${CACHE_ROOT}/torch"
17
- export HF_DATASETS_CACHE="${CACHE_ROOT}/datasets"
18
- export XDG_CACHE_HOME="${CACHE_ROOT}/xdg"
19
- mkdir -p "${HF_HOME}" "${TRANSFORMERS_CACHE}" "${TORCH_HOME}" "${HF_DATASETS_CACHE}" "${XDG_CACHE_HOME}"
20
-
21
- if command -v apt-get >/dev/null 2>&1; then
22
- apt-get update && apt-get install -y --no-install-recommends git ffmpeg libgl1 libglib2.0-0 || true
23
- fi
24
-
25
- pip install -q -U pip huggingface_hub
26
- pip install -q -e .
27
-
28
- export WJAD_OUTPUT_DIR="${WJAD_OUTPUT_DIR:-${BUCKET}/runs/current}"
29
- export WJAD_HUB_REPO="${HUB_REPO}"
30
- # 训练只读解压产物,不与 mirror 混用(解压脚本写入 extracted/,保持与源相同的相对路径树)
31
- export WJAD_DATA_ROOT="${WJAD_DATA_ROOT:-${BUCKET}/extracted/cosmos_hub}"
32
- mkdir -p "${WJAD_OUTPUT_DIR}"
33
-
34
- read -r -a BS_TRY <<< "${WJAD_OVERRIDE_BS_LIST:-12 10 8 6 4 2}"
35
- for BS in "${BS_TRY[@]}"; do
36
- echo "[jobs_entry_train] try batch_size=${BS}" >&2
37
- export WJAD_BATCH_SIZE="${BS}"
38
- if python -m wjad.train.runner_local \
39
- --config configs/default.yaml \
40
- --config_overrides configs/jobs_overrides.yaml \
41
- --device cuda \
42
- --data_root "${WJAD_DATA_ROOT}" \
43
- --dinov3_path ./dinov3-vitb16-pretrain-lvd1689m \
44
- --output_dir "${WJAD_OUTPUT_DIR}" \
45
- --hub_repo "${HUB_REPO}"; then
46
- echo "[jobs_entry_train] finished @ BS=${BS}" >&2
47
- exit 0
48
- fi
49
- echo "[jobs_entry_train] run failed @ BS=${BS}, lowering batch" >&2
50
- done
51
- exit 1
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# HF Job (GPU) entry point: runs from the repository root (push_to_jobs clones
# the repo first, then invokes this script).
# Data and cache paths default to the mounted Bucket (WJAD_BUCKET_MOUNT) so the
# Job's local disk is not consumed.
set -euo pipefail

# Resolve the repo root as the parent of this script's directory and work from there.
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$REPO_ROOT"

BUCKET="${WJAD_BUCKET_MOUNT:-/mnt/wjad}"
HUB_REPO="${WJAD_HUB_REPO:-fuzirui/WJAD}"
# All caches live under the bucket's cache/ directory, separate from the mirror
# and from the extracted data, so the whole directory can be deleted to clear caches.
CACHE_ROOT="${WJAD_CACHE_ROOT:-${BUCKET}/cache}"

export HF_HOME="${CACHE_ROOT}/hf"
export TRANSFORMERS_CACHE="${CACHE_ROOT}/transformers"
export TORCH_HOME="${CACHE_ROOT}/torch"
export HF_DATASETS_CACHE="${CACHE_ROOT}/datasets"
export XDG_CACHE_HOME="${CACHE_ROOT}/xdg"
# pip cache and temp files also go to the Bucket, so the K8s ephemeral-storage
# quota (50G) is not exhausted.
export PIP_CACHE_DIR="${CACHE_ROOT}/pip"
export TMPDIR="${CACHE_ROOT}/tmp"
export TEMP="${TMPDIR}"
export TMP="${TMPDIR}"
mkdir -p "${HF_HOME}" "${TRANSFORMERS_CACHE}" "${TORCH_HOME}" "${HF_DATASETS_CACHE}" "${XDG_CACHE_HOME}" \
  "${PIP_CACHE_DIR}" "${TMPDIR}"

# Best-effort system packages: "|| true" keeps the job going if apt fails.
if command -v apt-get >/dev/null 2>&1; then
  apt-get update && apt-get install -y --no-install-recommends git ffmpeg libgl1 libglib2.0-0 || true
fi

pip install -q -U pip huggingface_hub
pip install -q -e .

export WJAD_OUTPUT_DIR="${WJAD_OUTPUT_DIR:-${BUCKET}/runs/current}"
export WJAD_HUB_REPO="${HUB_REPO}"
# Training only reads the extracted output and never mixes with the mirror
# (the extract script writes into extracted/, keeping the source's relative path tree).
export WJAD_DATA_ROOT="${WJAD_DATA_ROOT:-${BUCKET}/extracted/cosmos_hub}"
mkdir -p "${WJAD_OUTPUT_DIR}"

# Try each batch size in order (space-separated list, overridable through
# WJAD_OVERRIDE_BS_LIST). The first run that exits 0 ends the script; ANY
# non-zero exit -- not only out-of-memory -- moves on to the next size.
read -r -a BS_TRY <<< "${WJAD_OVERRIDE_BS_LIST:-12 10 8 6 4 2}"
for BS in "${BS_TRY[@]}"; do
  echo "[jobs_entry_train] try batch_size=${BS}" >&2
  # NOTE(review): presumably the runner reads WJAD_BATCH_SIZE to override the
  # configured batch size -- confirm in wjad.train.runner_local.
  export WJAD_BATCH_SIZE="${BS}"
  if python -m wjad.train.runner_local \
    --config configs/default.yaml \
    --config_overrides configs/jobs_overrides.yaml \
    --device cuda \
    --data_root "${WJAD_DATA_ROOT}" \
    --dinov3_path ./dinov3-vitb16-pretrain-lvd1689m \
    --output_dir "${WJAD_OUTPUT_DIR}" \
    --hub_repo "${HUB_REPO}"; then
    echo "[jobs_entry_train] finished @ BS=${BS}" >&2
    exit 0
  fi
  echo "[jobs_entry_train] run failed @ BS=${BS}, lowering batch" >&2
done
# Every batch size in the list failed.
exit 1
scripts/jobs_extract_archives.py CHANGED
@@ -1,119 +1,119 @@
1
- """在已挂载的 Bucket 上解压归档:写入单独的 ``extracted/`` 树,不与 ``mirror/`` 混放。
2
-
3
- 扫描 ``--scan-root``(通常为 ``mirror/cosmos_hub``,即 Hub copy 镜像)下的归档;
4
- 解压到 ``--out-root`` / (相对 scan-root 的父路径) / (归档主文件名)/ ,
5
- 从而在 ``extracted/`` 下复现与源数据相同的**相对目录结构**,避免与镜像目录混淆。
6
-
7
- 解压产物只写在 ``--out-root`` 下;**不在 mirror 里落任何文件**(进度标记也放在
8
- ``out-root/_wjad_extract_state/``,与归档相对路径对应,避免修改 ``--scan-root``)。
9
-
10
- 示例(mount ``hf://buckets/.../WJAD`` → ``/mnt/wjad``)::
11
-
12
- python scripts/jobs_extract_archives.py \\
13
- --scan-root /mnt/wjad/mirror/cosmos_hub \\
14
- --out-root /mnt/wjad/extracted/cosmos_hub
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- import argparse
20
- import tarfile
21
- import zipfile
22
- from pathlib import Path
23
-
24
-
25
- def _archive_stem(path: Path) -> str:
26
- n = path.name
27
- lower = n.lower()
28
- for ext in (".tar.gz", ".tar.bz2", ".tgz"):
29
- if lower.endswith(ext):
30
- return n[: -len(ext)]
31
- if lower.endswith(".tar"):
32
- return n[:-4]
33
- if lower.endswith(".zip"):
34
- return n[:-4]
35
- return path.stem
36
-
37
-
38
- def _is_archive(path: Path) -> bool:
39
- lower = path.name.lower()
40
- return lower.endswith(
41
- (".tar.gz", ".tar.bz2", ".tgz", ".tar.xz", ".tar", ".zip"),
42
- )
43
-
44
-
45
- STATE_DIRNAME = "_wjad_extract_state"
46
-
47
-
48
- def _done_marker_path(archive: Path, scan: Path, out_root: Path) -> Path:
49
- """标记只写在 out_root 下,绝不写回 mirror。"""
50
- rel = archive.relative_to(scan)
51
- return out_root / STATE_DIRNAME / rel.parent / (rel.name + ".wjad_done")
52
-
53
-
54
- def _validate_roots(scan: Path, out_root: Path) -> None:
55
- s, o = scan.resolve(), out_root.resolve()
56
- if s == o:
57
- raise SystemExit("--out-root 不能与 --scan-root 相同,否则会写回镜像目录")
58
- try:
59
- o.relative_to(s)
60
- raise SystemExit("--out-root 不能位于 --scan-root 内部(mirror 只读,解压请用 extracted/)")
61
- except ValueError:
62
- pass
63
-
64
-
65
- def _extract_one(archive: Path, dest_dir: Path) -> None:
66
- dest_dir.mkdir(parents=True, exist_ok=True)
67
- lower = archive.name.lower()
68
- if lower.endswith(".zip"):
69
- with zipfile.ZipFile(archive, "r") as z:
70
- z.extractall(dest_dir)
71
- return
72
- mode = "r"
73
- if lower.endswith((".tar.gz", ".tgz")):
74
- mode = "r:gz"
75
- elif lower.endswith(".tar.bz2"):
76
- mode = "r:bz2"
77
- elif lower.endswith(".tar.xz"):
78
- mode = "r:xz"
79
- elif lower.endswith(".tar"):
80
- mode = "r:"
81
- else:
82
- raise ValueError(f"unsupported archive: {archive}")
83
- with tarfile.open(archive, mode) as tf:
84
- try:
85
- tf.extractall(dest_dir, filter=tarfile.data_filter) # py3.12+
86
- except TypeError:
87
- tf.extractall(dest_dir)
88
-
89
-
90
- def main() -> None:
91
- p = argparse.ArgumentParser()
92
- p.add_argument("--scan-root", type=Path, required=True)
93
- p.add_argument("--out-root", type=Path, required=True)
94
- args = p.parse_args()
95
- scan: Path = args.scan_root.resolve()
96
- out_root: Path = args.out_root.resolve()
97
- if not scan.is_dir():
98
- raise SystemExit(f"--scan-root not a directory: {scan}")
99
- _validate_roots(scan, out_root)
100
-
101
- count = 0
102
- for path in sorted(scan.rglob("*")):
103
- if not path.is_file() or not _is_archive(path):
104
- continue
105
- mpath = _done_marker_path(path, scan, out_root)
106
- if mpath.exists():
107
- continue
108
- rel_parent = path.parent.relative_to(scan)
109
- dest = out_root / rel_parent / _archive_stem(path)
110
- print(f"[extract] {path.relative_to(scan)} -> {dest.relative_to(out_root)}", flush=True)
111
- _extract_one(path, dest)
112
- mpath.parent.mkdir(parents=True, exist_ok=True)
113
- mpath.write_text("ok\n", encoding="utf-8")
114
- count += 1
115
- print(f"[extract] done, {count} archives", flush=True)
116
-
117
-
118
- if __name__ == "__main__":
119
- main()
 
1
+ """在已挂载的 Bucket 上解压归档:写入单独的 ``extracted/`` 树,不与 ``mirror/`` 混放。
2
+
3
+ 扫描 ``--scan-root``(通常为 ``mirror/cosmos_hub``,即 Hub copy 镜像)下的归档;
4
+ 解压到 ``--out-root`` / (相对 scan-root 的父路径) / (归档主文件名)/ ,
5
+ 从而在 ``extracted/`` 下复现与源数据相同的**相对目录结构**,避免与镜像目录混淆。
6
+
7
+ 解压产物只写在 ``--out-root`` 下;**不在 mirror 里落任何文件**(进度标记也放在
8
+ ``out-root/_wjad_extract_state/``,与归档相对路径对应,避免修改 ``--scan-root``)。
9
+
10
+ 示例(mount ``hf://buckets/.../WJAD`` → ``/mnt/wjad``)::
11
+
12
+ python scripts/jobs_extract_archives.py \\
13
+ --scan-root /mnt/wjad/mirror/cosmos_hub \\
14
+ --out-root /mnt/wjad/extracted/cosmos_hub
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import tarfile
21
+ import zipfile
22
+ from pathlib import Path
23
+
24
+
25
def _archive_stem(path: Path) -> str:
    """Return the archive's file name with its archive extension removed.

    Matching is case-insensitive, but the returned stem keeps the original
    casing. Every extension accepted by ``_is_archive`` is stripped as a whole,
    so ``foo.tar.xz`` yields ``foo`` (previously ``.tar.xz`` was missing from
    the list and produced ``foo.tar``). Names without a known archive
    extension fall back to ``Path.stem``.
    """
    name = path.name
    lowered = name.lower()
    # Compound ".tar.*" suffixes are listed before the bare ".tar" for clarity;
    # no entry is a suffix of another, so at most one can match.
    for ext in (".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tar", ".zip"):
        if lowered.endswith(ext):
            return name[: -len(ext)]
    return path.stem
36
+
37
+
38
def _is_archive(path: Path) -> bool:
    """Tell whether *path*'s file name ends with a supported archive extension.

    The comparison is case-insensitive and looks only at the name, never at
    the file contents.
    """
    known_suffixes = (".tar.gz", ".tar.bz2", ".tgz", ".tar.xz", ".tar", ".zip")
    file_name = path.name.lower()
    return any(file_name.endswith(suffix) for suffix in known_suffixes)
43
+
44
+
45
STATE_DIRNAME = "_wjad_extract_state"


def _done_marker_path(archive: Path, scan: Path, out_root: Path) -> Path:
    """Return the "done" marker path for *archive*, always located under *out_root*.

    The marker mirrors the archive's path relative to *scan* inside
    ``out_root/_wjad_extract_state`` and appends ``.wjad_done`` to the archive
    file name, so nothing is ever written back into the scanned tree.
    """
    relative = archive.relative_to(scan)
    marker_name = f"{relative.name}.wjad_done"
    return out_root.joinpath(STATE_DIRNAME, relative.parent, marker_name)
52
+
53
+
54
def _validate_roots(scan: Path, out_root: Path) -> None:
    """Abort unless *out_root* lies strictly outside *scan*.

    Both paths are resolved first. The function raises ``SystemExit`` when the
    two are the same directory or when the out-root is nested anywhere inside
    the scan root; otherwise it returns ``None``.
    """
    scan_abs = scan.resolve()
    out_abs = out_root.resolve()
    if out_abs == scan_abs:
        raise SystemExit("--out-root 不能与 --scan-root 相同,否则会写回镜像目录")
    # Being an ancestor of the out-root is the same as the out-root being
    # expressible relative to the scan root.
    if scan_abs in out_abs.parents:
        raise SystemExit("--out-root 不能位于 --scan-root 内部(mirror 只读,解压请用 extracted/)")
63
+
64
+
65
def _extract_one(archive: Path, dest_dir: Path) -> None:
    """Extract *archive* into *dest_dir*, creating the directory when missing.

    The archive type is chosen from the (case-insensitive) file name:
    ``.zip`` goes through ``zipfile``; ``.tar.gz``/``.tgz``, ``.tar.bz2``,
    ``.tar.xz`` and plain ``.tar`` go through ``tarfile``.

    Tar archives are extracted with the ``data`` extraction filter whenever
    the running interpreter provides it. Availability is tested up front:
    on interpreters without extraction filters ``tarfile.data_filter`` raises
    ``AttributeError``, which the former ``except TypeError`` fallback never
    caught, and that broad handler could also hide a genuine ``TypeError``
    and silently redo the extraction unfiltered.

    Raises:
        ValueError: if the file name has no supported archive extension.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    lower = archive.name.lower()

    if lower.endswith(".zip"):
        with zipfile.ZipFile(archive, "r") as zf:
            zf.extractall(dest_dir)
        return

    if lower.endswith((".tar.gz", ".tgz")):
        mode = "r:gz"
    elif lower.endswith(".tar.bz2"):
        mode = "r:bz2"
    elif lower.endswith(".tar.xz"):
        mode = "r:xz"
    elif lower.endswith(".tar"):
        mode = "r:"
    else:
        raise ValueError(f"unsupported archive: {archive}")

    with tarfile.open(archive, mode) as tf:
        if hasattr(tarfile, "data_filter"):
            tf.extractall(dest_dir, filter=tarfile.data_filter)
        else:
            # NOTE(review): interpreters without extraction filters extract
            # member paths unchecked; only feed trusted archives on those.
            tf.extractall(dest_dir)
88
+
89
+
90
def main() -> None:
    """Extract every not-yet-processed archive under ``--scan-root`` into ``--out-root``.

    Archives are visited in sorted path order. Each one is unpacked into
    ``out-root/<parent path relative to scan-root>/<archive stem>/`` and then
    recorded with a marker file, so a later run skips it. The marker is
    written only after extraction returns, which means an interrupted archive
    is extracted again on the next run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scan-root", type=Path, required=True)
    parser.add_argument("--out-root", type=Path, required=True)
    opts = parser.parse_args()

    scan: Path = opts.scan_root.resolve()
    out_root: Path = opts.out_root.resolve()
    if not scan.is_dir():
        raise SystemExit(f"--scan-root not a directory: {scan}")
    _validate_roots(scan, out_root)

    # Lazily filter the sorted listing down to regular files that look like archives.
    archives = (
        candidate
        for candidate in sorted(scan.rglob("*"))
        if candidate.is_file() and _is_archive(candidate)
    )

    extracted = 0
    for archive in archives:
        marker = _done_marker_path(archive, scan, out_root)
        if marker.exists():
            continue
        relative = archive.relative_to(scan)
        dest = out_root / relative.parent / _archive_stem(archive)
        print(f"[extract] {relative} -> {dest.relative_to(out_root)}", flush=True)
        _extract_one(archive, dest)
        marker.parent.mkdir(parents=True, exist_ok=True)
        marker.write_text("ok\n", encoding="utf-8")
        extracted += 1
    print(f"[extract] done, {extracted} archives", flush=True)
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
scripts/jobs_hub_copy_to_bucket.py CHANGED
@@ -1,31 +1,31 @@
1
- """Hub 服务端 copy:把 Hub 上已有 dataset/model/space 树拷贝到 Bucket(大文件走 xet hash,不占 Job 本地带宽)。
2
-
3
- 示例(在 HF Job 内,仅需 HF_TOKEN)::
4
-
5
- python scripts/jobs_hub_copy_to_bucket.py \\
6
- --src hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/ \\
7
- --dst hf://buckets/fuzirui/WJAD/mirror/cosmos_hub/
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- import argparse
13
-
14
- from huggingface_hub import copy_files
15
-
16
-
17
- def main() -> None:
18
- p = argparse.ArgumentParser()
19
- p.add_argument("--src", required=True, help="hf://datasets/... 或 hf://user/model/…")
20
- p.add_argument(
21
- "--dst",
22
- required=True,
23
- help="目标必须为 bucket,如 hf://buckets/fuzirui/WJAD/mirror/cosmos_hub/",
24
- )
25
- args = p.parse_args()
26
- copy_files(args.src, args.dst)
27
- print("[copy] OK", args.src, "->", args.dst)
28
-
29
-
30
- if __name__ == "__main__":
31
- main()
 
1
+ """Hub 服务端 copy:把 Hub 上已有 dataset/model/space 树拷贝到 Bucket(大文件走 xet hash,不占 Job 本地带宽)。
2
+
3
+ 示例(在 HF Job 内,仅需 HF_TOKEN)::
4
+
5
+ python scripts/jobs_hub_copy_to_bucket.py \\
6
+ --src hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/ \\
7
+ --dst hf://buckets/fuzirui/WJAD/mirror/cosmos_hub/
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+
14
+ from huggingface_hub import copy_files
15
+
16
+
17
def main() -> None:
    """Parse ``--src`` / ``--dst`` and hand them to ``huggingface_hub.copy_files``.

    Both options are required. A confirmation line is printed once
    ``copy_files`` returns; any exception it raises propagates to the caller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--src",
        required=True,
        help="hf://datasets/... 或 hf://user/model/…",
    )
    parser.add_argument(
        "--dst",
        required=True,
        help="目标必须为 bucket,如 hf://buckets/fuzirui/WJAD/mirror/cosmos_hub/",
    )
    opts = parser.parse_args()
    source, target = opts.src, opts.dst
    copy_files(source, target)
    print("[copy] OK", source, "->", target)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ main()
scripts/push_to_jobs.py CHANGED
@@ -1,306 +1,339 @@
1
- """提交 Hugging Face Jobs:数据集服务端 copy → CPU 挂载解压 → A10G-Large 正式训练。
2
-
3
- **持久化只在 Bucket 上**(挂载如 ``/mnt/wjad``)。Job 容器本地仅存 ``git clone`` 与 pip
4
- 临时文件大数据放不进本地盘
5
-
6
- Bucket 目录约定(相对挂载根)::
7
-
8
- - ``mirror/cosmos_hub/`` — ``copy_files`` 得到的 Hub 数据集镜像;**只作解压源**,不把解压产物写回这里。
9
- - ``extracted/cosmos_hub/`` — **仅解压输出**,相对路径与源一致,训练 ``--data_root`` 指向这里。
10
- - ``cache/`` — 所有 HF / PyTorch / transformers 缓存根目录,**可整目录删** 以清缓存。
11
- - ``runs/current/`` — checkpoint(仍在 Bucket);另可 ``upload_file`` 到 Hub model repo。
12
-
13
- 常用命令::
14
-
15
- python scripts/push_to_jobs.py copy-extract # 先 copy(等待成功)再 submit extract(前台跑完日志)
16
-
17
- 单独步骤加 ``--detach``。Windows 上会使用 ``sys.executable`` 同目录的 ``hf.exe``。
18
- """
19
-
20
- from __future__ import annotations
21
-
22
- import argparse
23
- import json
24
- import re
25
- import shlex
26
- import shutil
27
- import subprocess
28
- import sys
29
- import time
30
- from pathlib import Path
31
-
32
- DEFAULT_BUCKET = "fuzirui/WJAD"
33
- DEFAULT_CODE = "fuzirui/WJAD"
34
- DEFAULT_WEIGHTS = "fuzirui/WJAD"
35
- DEFAULT_SRC = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
36
- DEFAULT_MIRROR = "mirror/cosmos_hub"
37
- DEFAULT_EXTRACTED = "extracted/cosmos_hub"
38
- DEFAULT_CACHE = "cache"
39
- DEFAULT_TIMEOUT = "7d"
40
- TRAIN_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime"
41
- TRAIN_FLAVOR = "a10g-large"
42
-
43
-
44
- def _hf_executable() -> str:
45
- sibling = Path(sys.executable).parent / ("hf.exe" if sys.platform == "win32" else "hf")
46
- if sibling.is_file():
47
- return str(sibling.resolve())
48
- w = shutil.which("hf")
49
- return w if w else "hf"
50
-
51
-
52
- def _hf_jobs_cmd(parts: list[str]) -> list[str]:
53
- if parts and parts[0] == "hf":
54
- return [_hf_executable(), *parts[1:]]
55
- return parts
56
-
57
-
58
- def _insert_detach(cmd: list[str], detach: bool) -> None:
59
- if not detach:
60
- return
61
- idx = cmd.index("run") + 1
62
- cmd.insert(idx, "--detach")
63
-
64
-
65
- def build_copy_cmd(args: argparse.Namespace) -> list[str]:
66
- dst = f"hf://buckets/{args.bucket}/{args.mirror_subpath.rstrip('/')}/"
67
- one_py = (
68
- "import subprocess,sys;"
69
- "subprocess.check_call([sys.executable,'-m','pip','install','-q','-U','huggingface_hub']);"
70
- "from huggingface_hub import copy_files;"
71
- f"copy_files({args.src!r},{dst!r})"
72
- )
73
- cmd = [
74
- "hf",
75
- "jobs",
76
- "run",
77
- "--flavor",
78
- "cpu-basic",
79
- "--timeout",
80
- args.timeout,
81
- "--secrets",
82
- "HF_TOKEN",
83
- "python:3.12",
84
- "python",
85
- "-c",
86
- one_py,
87
- ]
88
- _insert_detach(cmd, args.detach)
89
- return cmd
90
-
91
-
92
- def build_extract_cmd(args: argparse.Namespace) -> list[str]:
93
- vol = f"hf://buckets/{args.bucket}:/mnt/wjad"
94
- inner = " && ".join(
95
- [
96
- "set -e",
97
- "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git)",
98
- "pip install -q -U huggingface_hub",
99
- f"git clone https://oauth2:$HF_TOKEN@huggingface.co/{args.code_repo} /tmp/wjad",
100
- "cd /tmp/wjad",
101
- "python scripts/jobs_extract_archives.py "
102
- f"--scan-root /mnt/wjad/{args.mirror_subpath} --out-root /mnt/wjad/{args.extracted_subpath}",
103
- ]
104
- )
105
- cmd = [
106
- "hf",
107
- "jobs",
108
- "run",
109
- "-v",
110
- vol,
111
- "--flavor",
112
- "cpu-basic",
113
- "--timeout",
114
- args.timeout,
115
- "--secrets",
116
- "HF_TOKEN",
117
- "python:3.12",
118
- "sh",
119
- "-c",
120
- inner,
121
- ]
122
- _insert_detach(cmd, args.detach)
123
- return cmd
124
-
125
-
126
- def build_train_cmd(args: argparse.Namespace) -> list[str]:
127
- vol = f"hf://buckets/{args.bucket}:/mnt/wjad"
128
- inner = " && ".join(
129
- [
130
- "set -e",
131
- "export WJAD_BUCKET_MOUNT=/mnt/wjad",
132
- f"export WJAD_CACHE_ROOT=/mnt/wjad/{args.cache_subpath}",
133
- f"export WJAD_HUB_REPO={args.weights_repo}",
134
- f"export WJAD_DATA_ROOT=/mnt/wjad/{args.extracted_subpath}",
135
- "export WJAD_OUTPUT_DIR=/mnt/wjad/runs/current",
136
- "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git)",
137
- "pip install -q -U huggingface_hub",
138
- f"git clone https://oauth2:$HF_TOKEN@huggingface.co/{args.code_repo} /workspace/wjad",
139
- "cd /workspace/wjad",
140
- "pip install -q -U pip",
141
- "pip install -q -e .",
142
- "bash scripts/jobs_entry_train.sh",
143
- ]
144
- )
145
- cmd = [
146
- "hf",
147
- "jobs",
148
- "run",
149
- "-v",
150
- vol,
151
- "--flavor",
152
- args.train_flavor,
153
- "--timeout",
154
- args.timeout,
155
- "--secrets",
156
- "HF_TOKEN",
157
- args.train_image,
158
- "bash",
159
- "-lc",
160
- inner,
161
- ]
162
- _insert_detach(cmd, args.detach)
163
- return cmd
164
-
165
-
166
- def _run(cmd: list[str]) -> int:
167
- print("[push_to_jobs] $", " ".join(shlex.quote(c) for c in cmd))
168
- cmd = _hf_jobs_cmd(cmd)
169
- return subprocess.call(cmd)
170
-
171
-
172
- def _run_capture(cmd: list[str]) -> subprocess.CompletedProcess:
173
- cmd = _hf_jobs_cmd(cmd)
174
- print("[push_to_jobs] $", " ".join(shlex.quote(c) for c in cmd))
175
- return subprocess.run(cmd, capture_output=True, text=True)
176
-
177
-
178
- def _parse_detach_job_id(stdout: str, stderr: str) -> str | None:
179
- text = (stdout + "\n" + stderr).strip()
180
- m = re.search(r"\bID:\s*([a-fA-F0-9]{12,})\b", text)
181
- if m:
182
- return m.group(1)
183
- m = re.search(r"/jobs/[^/\s]+/([a-fA-F0-9]{12,})", text)
184
- if m:
185
- return m.group(1)
186
- for line in text.splitlines():
187
- line = line.strip()
188
- low = line.lower()
189
- if "job" in low and "id" in low:
190
- parts = line.replace(":", " ").split()
191
- for i, p in enumerate(parts):
192
- if p.lower() == "id" and i + 1 < len(parts):
193
- return parts[i + 1].strip().rstrip(",")
194
- return None
195
-
196
-
197
- def _job_status(job_id: str) -> str | None:
198
- r = _run_capture(["hf", "jobs", "inspect", job_id, "--json"])
199
- if r.returncode != 0 or not r.stdout.strip():
200
- return None
201
- try:
202
- data = json.loads(r.stdout)
203
- row = data[0] if isinstance(data, list) and data else data if isinstance(data, dict) else None
204
- if not isinstance(row, dict):
205
- return None
206
- st = row.get("status")
207
- if isinstance(st, dict):
208
- stage = st.get("stage")
209
- if isinstance(stage, str):
210
- return stage.lower()
211
- if isinstance(st, str):
212
- return st.lower()
213
- except json.JSONDecodeError:
214
- pass
215
- return None
216
-
217
-
218
- def _wait_job(job_id: str, poll_s: float = 45.0, label: str = "") -> int:
219
- print(f"[push_to_jobs] 轮询 Job: {job_id} ({label})")
220
- terminal_ok = ("completed", "succeeded", "success", "done")
221
- terminal_bad = ("failed", "error", "cancelled", "canceled", "stopped")
222
- while True:
223
- st = _job_status(job_id)
224
- if st:
225
- print(f"[push_to_jobs] 状态: {st}")
226
- if st in terminal_ok:
227
- return 0
228
- if st in terminal_bad:
229
- return 1
230
- time.sleep(poll_s)
231
-
232
-
233
- def _submit_detach(cmd: list[str]) -> tuple[int, str]:
234
- r = _run_capture(cmd)
235
- out = (r.stdout or "") + "\n" + (r.stderr or "")
236
- if r.returncode != 0:
237
- print(out)
238
- return r.returncode, ""
239
- jid = _parse_detach_job_id(r.stdout or "", r.stderr or "")
240
- if jid:
241
- print(f"[push_to_jobs] Job ID: {jid}")
242
- else:
243
- print(out)
244
- return r.returncode, jid or ""
245
-
246
-
247
- def main() -> None:
248
- p = argparse.ArgumentParser(description="Submit HF Jobs: copy / extract / train.")
249
- p.add_argument("--bucket", default=DEFAULT_BUCKET)
250
- p.add_argument("--code-repo", default=DEFAULT_CODE, dest="code_repo")
251
- p.add_argument("--weights-repo", default=DEFAULT_WEIGHTS, dest="weights_repo")
252
- p.add_argument("--src", default=DEFAULT_SRC, dest="src")
253
- p.add_argument("--mirror-subpath", default=DEFAULT_MIRROR, dest="mirror_subpath")
254
- p.add_argument(
255
- "--extracted-subpath",
256
- default=DEFAULT_EXTRACTED,
257
- dest="extracted_subpath",
258
- help="解压/训练用的数据集根路径(默认 extracted/cosmos_hub,与 mirror 分离)",
259
- )
260
- p.add_argument(
261
- "--cache-subpath",
262
- default=DEFAULT_CACHE,
263
- dest="cache_subpath",
264
- help="Bucket 内缓存根目录(默认 cache,可整体删除)",
265
- )
266
- p.add_argument("--timeout", default=DEFAULT_TIMEOUT)
267
- p.add_argument("--detach", action="store_true")
268
- p.add_argument("--train-image", default=TRAIN_IMAGE, dest="train_image")
269
- p.add_argument("--train-flavor", default=TRAIN_FLAVOR, dest="train_flavor")
270
- p.add_argument(
271
- "action",
272
- choices=("copy", "extract", "train", "print-plan", "copy-extract"),
273
- )
274
- args = p.parse_args()
275
-
276
- builders = {
277
- "copy": build_copy_cmd,
278
- "extract": build_extract_cmd,
279
- "train": build_train_cmd,
280
- }
281
- if args.action == "print-plan":
282
- for name, b in builders.items():
283
- print(f"--- {name} ---")
284
- print(" ".join(shlex.quote(c) for c in b(args)))
285
- print()
286
- return
287
-
288
- if args.action == "copy-extract":
289
- cargs = argparse.Namespace(**vars(args))
290
- cargs.detach = True
291
- rc, jid = _submit_detach(build_copy_cmd(cargs))
292
- if rc != 0:
293
- sys.exit(rc)
294
- if not jid:
295
- print("[push_to_jobs] 未解析到 copy 的 Job ID。请到 Hub Jobs 查看后手动: python scripts/push_to_jobs.py extract")
296
- sys.exit(1)
297
- rc = _wait_job(jid, label="copy to bucket")
298
- if rc != 0:
299
- sys.exit(rc)
300
- sys.exit(_run(build_extract_cmd(args)))
301
-
302
- sys.exit(_run(builders[args.action](args)))
303
-
304
-
305
- if __name__ == "__main__":
306
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """提交 Hugging Face Jobs:数据集服务端 copy → CPU 挂载解压 → A10G-Large 正式训练。
2
+
3
+ **持久化只在 Bucket 上**(挂载如 ``/mnt/wjad``)。容器本地只应保留镜像层;
4
+ ``pip`` / ``TMPDIR`` / ``HF_HOME`` 等均指向 ``cache/`` 避免 ephemeral 配额(如 50G)被撑爆
5
+
6
+ Bucket 目录约定(相对挂载根)::
7
+
8
+ - ``mirror/cosmos_hub/`` — ``copy_files`` 得到的 Hub 数据集镜像;**只作解压源**,不把解压产物写回这里。
9
+ - ``extracted/cosmos_hub/`` — **仅解压输出**,相对路径与源一致,训练 ``--data_root`` 指向这里。
10
+ - ``cache/`` — pip、TMPDIR、HF/torch/transformers 缓存及 ``wjad_repo``(clone 代码),**可整目录删** 以清缓存或换代码版本
11
+ - ``runs/current/`` — checkpoint(仍在 Bucket);另可 ``upload_file`` 到 Hub model repo。
12
+
13
+ 常用命令::
14
+
15
+ python scripts/push_to_jobs.py copy-extract # 先 copy(等待成功)再 submit extract(前台跑完日志)
16
+
17
+ - ``--replace-extracted`` 解压前 **删除** ``extracted/...`` 下全部内容(含 ``_wjad_extract_state``),再全量重解压
18
+
19
+ 单独步骤可加 ``--detach``。Windows 上会使用 ``sys.executable`` 同目录的 ``hf.exe``。
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import json
26
+ import re
27
+ import shlex
28
+ import shutil
29
+ import subprocess
30
+ import sys
31
+ import time
32
+ from pathlib import Path
33
+
34
+ DEFAULT_BUCKET = "fuzirui/WJAD"
35
+ DEFAULT_CODE = "fuzirui/WJAD"
36
+ DEFAULT_WEIGHTS = "fuzirui/WJAD"
37
+ DEFAULT_SRC = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
38
+ DEFAULT_MIRROR = "mirror/cosmos_hub"
39
+ DEFAULT_EXTRACTED = "extracted/cosmos_hub"
40
+ DEFAULT_CACHE = "cache"
41
+ DEFAULT_TIMEOUT = "7d"
42
+ TRAIN_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime"
43
+ TRAIN_FLAVOR = "a10g-large"
44
+
45
+
46
def _hf_executable() -> str:
    """Locate the ``hf`` command-line executable.

    Preference order: an ``hf`` (``hf.exe`` on Windows) file sitting next to
    the running Python interpreter, then whatever ``shutil.which("hf")``
    finds, and finally the bare name ``"hf"``.
    """
    exe_name = "hf.exe" if sys.platform == "win32" else "hf"
    candidate = Path(sys.executable).parent / exe_name
    if candidate.is_file():
        return str(candidate.resolve())
    return shutil.which("hf") or "hf"
52
+
53
+
54
def _hf_jobs_cmd(parts: list[str]) -> list[str]:
    """Swap a leading ``"hf"`` token for the resolved ``hf`` executable.

    Commands that are empty or do not start with ``"hf"`` are returned
    unchanged (the very same list object); otherwise a new list is built.
    """
    if not parts or parts[0] != "hf":
        return parts
    _, *rest = parts
    return [_hf_executable(), *rest]
58
+
59
+
60
def _insert_detach(cmd: list[str], detach: bool) -> None:
    """Insert ``--detach`` right after the first ``"run"`` token of *cmd*, in place.

    Nothing happens when *detach* is false. When it is true and *cmd* has no
    ``"run"`` element, ``list.index`` raises ``ValueError``.
    """
    if detach:
        cmd.insert(cmd.index("run") + 1, "--detach")
65
+
66
+
67
def build_copy_cmd(args: argparse.Namespace) -> list[str]:
    """Build the ``hf jobs run`` command for the Hub-to-bucket copy job.

    The job runs a one-line Python program in the ``python:3.12`` image on the
    ``cpu-basic`` flavor: it upgrades ``huggingface_hub`` and then calls
    ``copy_files(args.src, dst)``, where ``dst`` is
    ``hf://buckets/<bucket>/<mirror_subpath>/``. ``--detach`` is added when
    ``args.detach`` is set.
    """
    mirror = args.mirror_subpath.rstrip("/")
    dst = f"hf://buckets/{args.bucket}/{mirror}/"
    # Statements of the in-job program, glued with ";" so it fits `python -c`.
    script = ";".join(
        [
            "import subprocess,sys",
            "subprocess.check_call([sys.executable,'-m','pip','install','-q','-U','huggingface_hub'])",
            "from huggingface_hub import copy_files",
            f"copy_files({args.src!r},{dst!r})",
        ]
    )
    cmd = ["hf", "jobs", "run"]
    cmd += ["--flavor", "cpu-basic"]
    cmd += ["--timeout", args.timeout]
    cmd += ["--secrets", "HF_TOKEN"]
    cmd += ["python:3.12", "python", "-c", script]
    _insert_detach(cmd, args.detach)
    return cmd
92
+
93
+
94
def _bucket_cache_env_sh(cache_subpath: str) -> str:
    """Return one shell line that points pip, temp and HF caches into the bucket.

    All directories live under ``/mnt/wjad/<cache_subpath>``. The line exports
    ``WJAD_CACHE_ROOT``, creates the sub-directories with ``mkdir -p`` and
    exports ``PIP_CACHE_DIR``, ``TMPDIR``/``TEMP``/``TMP``, ``HF_HOME``,
    ``TRANSFORMERS_CACHE``, ``TORCH_HOME``, ``HF_DATASETS_CACHE`` and
    ``XDG_CACHE_HOME``, all joined with ``&&``.

    Every path is passed through ``shlex.quote`` so that a sub-path containing
    spaces or shell metacharacters cannot break or alter the command. For
    ordinary sub-paths (such as the default ``cache``) the output is identical
    to the unquoted form.
    """
    root = f"/mnt/wjad/{cache_subpath}"

    def quoted(leaf: str = "") -> str:
        # Quote the complete path rather than its pieces so the result is one shell word.
        return shlex.quote(f"{root}/{leaf}" if leaf else root)

    directories = ("pip", "tmp", "hf", "transformers", "torch", "datasets", "xdg")
    exports = (
        ("PIP_CACHE_DIR", "pip"),
        ("TMPDIR", "tmp"),
        ("TEMP", "tmp"),
        ("TMP", "tmp"),
        ("HF_HOME", "hf"),
        ("TRANSFORMERS_CACHE", "transformers"),
        ("TORCH_HOME", "torch"),
        ("HF_DATASETS_CACHE", "datasets"),
        ("XDG_CACHE_HOME", "xdg"),
    )
    steps = [
        f"export WJAD_CACHE_ROOT={quoted()}",
        "mkdir -p " + " ".join(quoted(name) for name in directories),
    ]
    steps.extend(f"export {var}={quoted(leaf)}" for var, leaf in exports)
    return " && ".join(steps)
104
+
105
+
106
def build_extract_cmd(args: argparse.Namespace) -> list[str]:
    """Build the ``hf jobs run`` command for the CPU extract job.

    The job mounts the bucket at ``/mnt/wjad``, redirects caches into the
    bucket, clones the code repo into ``<cache_subpath>/wjad_repo`` (removing
    any previous clone first) and runs ``scripts/jobs_extract_archives.py``
    from ``<mirror_subpath>`` into ``<extracted_subpath>``.

    When ``args.replace_extracted`` is set, the extracted directory is wiped
    and recreated before extraction. Because that step is an ``rm -rf`` on the
    mounted bucket, the sub-path is validated first: an empty, ``.``-only or
    ``..``-containing value would otherwise target the mount root or escape
    the intended directory.

    All bucket paths embedded in the shell line are passed through
    ``shlex.quote``; for ordinary paths the resulting words are unchanged.

    Raises:
        SystemExit: if ``--replace-extracted`` is combined with an unsafe
            ``extracted_subpath``.
    """
    vol = f"hf://buckets/{args.bucket}:/mnt/wjad"
    cache_setup = _bucket_cache_env_sh(args.cache_subpath)
    clone_dir = shlex.quote(f"/mnt/wjad/{args.cache_subpath}/wjad_repo")
    ext_raw = f"/mnt/wjad/{args.extracted_subpath}"
    ext_dir = shlex.quote(ext_raw)
    scan_dir = shlex.quote(f"/mnt/wjad/{args.mirror_subpath}")

    steps: list[str] = [
        "set -e",
        "export WJAD_BUCKET_MOUNT=/mnt/wjad",
        cache_setup,
    ]
    if getattr(args, "replace_extracted", False):
        # Refuse to build an `rm -rf` that could hit the mount root or leave it.
        segments = [seg for seg in args.extracted_subpath.split("/") if seg not in ("", ".")]
        if not segments or ".." in segments:
            raise SystemExit(
                "[push_to_jobs] refusing --replace-extracted: unsafe --extracted-subpath "
                f"{args.extracted_subpath!r}"
            )
        wipe_msg = shlex.quote(f"[extract] --replace-extracted: wipe {ext_raw}")
        steps.append(f"echo {wipe_msg} && rm -rf {ext_dir} && mkdir -p {ext_dir}")
    steps.extend(
        [
            "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git)",
            "pip install -q -U huggingface_hub",
            # $HF_TOKEN must stay unquoted so the job's shell expands it.
            f"rm -rf {clone_dir} && git clone https://oauth2:$HF_TOKEN@huggingface.co/{args.code_repo} {clone_dir}",
            f"cd {clone_dir}",
            "python scripts/jobs_extract_archives.py "
            f"--scan-root {scan_dir} --out-root {ext_dir}",
        ]
    )
    inner = " && ".join(steps)
    cmd = [
        "hf",
        "jobs",
        "run",
        "-v",
        vol,
        "--flavor",
        "cpu-basic",
        "--timeout",
        args.timeout,
        "--secrets",
        "HF_TOKEN",
        "python:3.12",
        "sh",
        "-c",
        inner,
    ]
    _insert_detach(cmd, args.detach)
    return cmd
150
+
151
+
152
def build_train_cmd(args: argparse.Namespace) -> list[str]:
    """Build the ``hf jobs run`` argv that launches training on the bucket data.

    The job mounts the bucket at ``/mnt/wjad``, applies the cache environment
    from ``_bucket_cache_env_sh``, exports ``WJAD_HUB_REPO``,
    ``WJAD_DATA_ROOT`` and ``WJAD_OUTPUT_DIR``, clones ``args.code_repo`` into
    the cache area, installs it in editable mode and runs
    ``scripts/jobs_entry_train.sh``.

    All CLI-supplied values spliced into the shell script are quoted with
    ``shlex.quote`` so that spaces or shell metacharacters cannot break the
    script; for plain values the generated script is unchanged.

    Args:
        args: Parsed CLI namespace; reads ``bucket``, ``cache_subpath``,
            ``extracted_subpath``, ``weights_repo``, ``code_repo``,
            ``train_flavor``, ``train_image``, ``timeout`` and ``detach``.

    Returns:
        The argument vector, with ``--detach`` inserted when requested.
    """
    quote = shlex.quote
    # ``vol`` is passed as its own argv element (no shell), so it needs no quoting.
    vol = f"hf://buckets/{args.bucket}:/mnt/wjad"
    cache_setup = _bucket_cache_env_sh(args.cache_subpath)
    clone_dir = quote(f"/mnt/wjad/{args.cache_subpath}/wjad_repo")
    data_root = quote(f"/mnt/wjad/{args.extracted_subpath}")
    steps = [
        "set -e",
        "export WJAD_BUCKET_MOUNT=/mnt/wjad",
        cache_setup,
        f"export WJAD_HUB_REPO={quote(args.weights_repo)}",
        f"export WJAD_DATA_ROOT={data_root}",
        "export WJAD_OUTPUT_DIR=/mnt/wjad/runs/current",
        "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git)",
        "pip install -q -U huggingface_hub",
        # ``$HF_TOKEN`` must stay unquoted so the job's shell expands it.
        f"rm -rf {clone_dir} && git clone https://oauth2:$HF_TOKEN@huggingface.co/{quote(args.code_repo)} {clone_dir}",
        f"cd {clone_dir}",
        "pip install -q -U pip",
        "pip install -q -e .",
        "bash scripts/jobs_entry_train.sh",
    ]
    inner = " && ".join(steps)
    cmd = [
        "hf",
        "jobs",
        "run",
        "-v",
        vol,
        "--flavor",
        args.train_flavor,
        "--timeout",
        args.timeout,
        "--secrets",
        "HF_TOKEN",
        args.train_image,
        "bash",
        "-lc",
        inner,
    ]
    _insert_detach(cmd, args.detach)
    return cmd
192
+
193
+
194
def _run(cmd: list[str]) -> int:
    """Echo *cmd* as given, run it, and return the process exit status.

    The echoed line shows the command before it is passed through
    ``_hf_jobs_cmd``; the command actually executed is the result of that call.
    Output is not captured.
    """
    shown = " ".join(map(shlex.quote, cmd))
    print("[push_to_jobs] $", shown)
    return subprocess.call(_hf_jobs_cmd(cmd))
198
+
199
+
200
def _run_capture(cmd: list[str]) -> subprocess.CompletedProcess:
    """Run *cmd* with stdout/stderr captured as text and return the result.

    The command is first passed through ``_hf_jobs_cmd``; the echoed line shows
    that resulting form.
    """
    resolved = _hf_jobs_cmd(cmd)
    print("[push_to_jobs] $", " ".join(map(shlex.quote, resolved)))
    return subprocess.run(resolved, capture_output=True, text=True)
204
+
205
+
206
+ def _parse_detach_job_id(stdout: str, stderr: str) -> str | None:
207
+ text = (stdout + "\n" + stderr).strip()
208
+ m = re.search(r"\bID:\s*([a-fA-F0-9]{12,})\b", text)
209
+ if m:
210
+ return m.group(1)
211
+ m = re.search(r"/jobs/[^/\s]+/([a-fA-F0-9]{12,})", text)
212
+ if m:
213
+ return m.group(1)
214
+ for line in text.splitlines():
215
+ line = line.strip()
216
+ low = line.lower()
217
+ if "job" in low and "id" in low:
218
+ parts = line.replace(":", " ").split()
219
+ for i, p in enumerate(parts):
220
+ if p.lower() == "id" and i + 1 < len(parts):
221
+ return parts[i + 1].strip().rstrip(",")
222
+ return None
223
+
224
+
225
def _job_status(job_id: str) -> str | None:
    """Return the lower-cased status of *job_id*, or ``None`` if unavailable.

    Runs ``hf jobs inspect <job_id> --json`` and reads the ``status`` field of
    the first record (the JSON may be a list of records or a single object).
    ``status`` may be a string, or an object whose ``stage`` string is used.

    ``None`` is returned when the command fails, prints nothing, emits invalid
    JSON, or the payload has none of the shapes above.
    """
    proc = _run_capture(["hf", "jobs", "inspect", job_id, "--json"])
    if proc.returncode != 0 or not proc.stdout.strip():
        return None

    try:
        payload = json.loads(proc.stdout)
    except json.JSONDecodeError:
        return None

    if isinstance(payload, list) and payload:
        record = payload[0]
    elif isinstance(payload, dict):
        record = payload
    else:
        return None
    if not isinstance(record, dict):
        return None

    status = record.get("status")
    if isinstance(status, dict):
        stage = status.get("stage")
        if isinstance(stage, str):
            return stage.lower()
    if isinstance(status, str):
        return status.lower()
    return None
244
+
245
+
246
def _wait_job(job_id: str, poll_s: float = 45.0, label: str = "") -> int:
    """Poll *job_id* until it reaches a terminal state.

    Args:
        job_id: Id of the job to watch.
        poll_s: Seconds to sleep between status checks.
        label: Free-form description printed alongside the id.

    Returns:
        ``0`` when the job ended successfully, ``1`` when it failed, was
        cancelled, stopped or deleted. Polls that yield no status, or a
        non-terminal status, are retried indefinitely.
    """
    print(f"[push_to_jobs] 轮询 Job: {job_id} ({label})")
    terminal_ok = frozenset({"completed", "succeeded", "success", "done"})
    # "deleted" is a terminal stage reported by the Hub; without it a job that
    # was deleted while we wait would be polled forever.
    terminal_bad = frozenset(
        {"failed", "error", "cancelled", "canceled", "stopped", "deleted"}
    )
    while True:
        st = _job_status(job_id)
        if st:
            print(f"[push_to_jobs] 状态: {st}")
            if st in terminal_ok:
                return 0
            if st in terminal_bad:
                return 1
        time.sleep(poll_s)
259
+
260
+
261
def _submit_detach(cmd: list[str]) -> tuple[int, str]:
    """Submit a detached job and return ``(exit_code, job_id)``.

    ``job_id`` is ``""`` when the submission failed or no id could be parsed
    from its output. In both of those cases the captured stdout/stderr is
    printed; on success only the parsed id is printed.
    """
    proc = _run_capture(cmd)
    stdout = proc.stdout or ""
    stderr = proc.stderr or ""
    combined = stdout + "\n" + stderr

    if proc.returncode != 0:
        print(combined)
        return proc.returncode, ""

    job_id = _parse_detach_job_id(stdout, stderr)
    print(f"[push_to_jobs] Job ID: {job_id}" if job_id else combined)
    return proc.returncode, job_id or ""
273
+
274
+
275
def main() -> None:
    """CLI entry point: parse arguments and submit (or print) the chosen job.

    Actions:
        copy / extract / train: build the matching command and run it; the
            process exits with that command's status.
        print-plan: print all three commands shell-quoted, without running them.
        copy-extract: submit the copy job detached, poll it until it ends,
            then run the extract job.
    """
    parser = argparse.ArgumentParser(description="Submit HF Jobs: copy / extract / train.")
    add = parser.add_argument
    add("--bucket", default=DEFAULT_BUCKET)
    add("--code-repo", default=DEFAULT_CODE, dest="code_repo")
    add("--weights-repo", default=DEFAULT_WEIGHTS, dest="weights_repo")
    add("--src", default=DEFAULT_SRC, dest="src")
    add("--mirror-subpath", default=DEFAULT_MIRROR, dest="mirror_subpath")
    add(
        "--extracted-subpath",
        default=DEFAULT_EXTRACTED,
        dest="extracted_subpath",
        help="解压/训练用的数据集根路径(默认 extracted/cosmos_hub,与 mirror 分离)",
    )
    add(
        "--cache-subpath",
        default=DEFAULT_CACHE,
        dest="cache_subpath",
        help="Bucket 内缓存根目录(默认 cache,可整体删除)",
    )
    add("--timeout", default=DEFAULT_TIMEOUT)
    add("--detach", action="store_true")
    add(
        "--replace-extracted",
        action="store_true",
        help="仅 extract:先 rm -rf extracted 目标目录再解压,完全替换旧解压结果",
    )
    add("--train-image", default=TRAIN_IMAGE, dest="train_image")
    add("--train-flavor", default=TRAIN_FLAVOR, dest="train_flavor")
    add(
        "action",
        choices=("copy", "extract", "train", "print-plan", "copy-extract"),
    )
    args = parser.parse_args()

    builders = {
        "copy": build_copy_cmd,
        "extract": build_extract_cmd,
        "train": build_train_cmd,
    }

    if args.action == "print-plan":
        for name, build in builders.items():
            print(f"--- {name} ---")
            print(" ".join(shlex.quote(part) for part in build(args)))
            print()
        return

    if args.action != "copy-extract":
        sys.exit(_run(builders[args.action](args)))

    # copy-extract: the copy step is always detached so it can be polled,
    # whatever --detach says; the user's --detach applies to the extract step.
    copy_args = argparse.Namespace(**vars(args))
    copy_args.detach = True
    rc, jid = _submit_detach(build_copy_cmd(copy_args))
    if rc != 0:
        sys.exit(rc)
    if not jid:
        print("[push_to_jobs] 未解析到 copy 的 Job ID。请到 Hub Jobs 查看后手动: python scripts/push_to_jobs.py extract")
        sys.exit(1)
    rc = _wait_job(jid, label="copy to bucket")
    if rc != 0:
        sys.exit(rc)
    sys.exit(_run(build_extract_cmd(args)))
336
+
337
+
338
# Run the CLI only when executed as a script, so importing this module has no
# side effects.
if __name__ == "__main__":
    main()