fuzirui committed on
Commit
2e8e068
·
verified ·
1 Parent(s): 9200c30

Upload folder using huggingface_hub

Browse files
configs/cosmos_hub_extract_lidar_and_generation_only.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # 三个并行 Job:① lidar_raw ② pose(磁盘问题重试)③ single_view 内仅 generation(分卷),不碰 caption/hdmap
2
+ # extract-parallel --replace-subdirs-file ... --shards-file 本文件
3
+
4
+ lidar_raw
5
+ pose
6
+ cosmos_synthetic/single_view|only=generation
configs/cosmos_hub_replace_lidar_and_generation.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # 重试前仅删除下列 extracted 子树,其它不动(路径相对 extracted-subpath,默认 extracted/cosmos_hub)
2
+
3
+ lidar_raw
4
+ pose
5
+ cosmos_synthetic/single_view/generation
scripts/jobs_extract_archives.py CHANGED
@@ -7,6 +7,15 @@
7
  解压产物只写在 ``--out-root`` 下;**不在 mirror 里落任何文件**(进度标记也放在
8
  ``out-root/_wjad_extract_state/``,与归档相对路径对应,避免修改 ``--scan-root``)。
9
 
 
 
 
 
 
 
 
 
 
10
  示例(mount ``hf://buckets/.../WJAD`` → ``/mnt/wjad``)::
11
 
12
  python scripts/jobs_extract_archives.py \\
@@ -17,10 +26,144 @@
17
  from __future__ import annotations
18
 
19
  import argparse
 
 
 
 
 
20
  import tarfile
 
 
21
  import zipfile
 
22
  from pathlib import Path
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def _archive_stem(path: Path) -> str:
26
  n = path.name
@@ -42,9 +185,6 @@ def _is_archive(path: Path) -> bool:
42
  )
43
 
44
 
45
- STATE_DIRNAME = "_wjad_extract_state"
46
-
47
-
48
  def _done_marker_path(archive: Path, scan: Path, out_root: Path) -> Path:
49
  """标记只写在 out_root 下,绝不写回 mirror。"""
50
  rel = archive.relative_to(scan)
@@ -87,31 +227,86 @@ def _extract_one(archive: Path, dest_dir: Path) -> None:
87
  tf.extractall(dest_dir)
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def main() -> None:
91
  p = argparse.ArgumentParser()
92
  p.add_argument("--scan-root", type=Path, required=True)
93
  p.add_argument("--out-root", type=Path, required=True)
 
 
 
 
 
94
  args = p.parse_args()
 
95
  scan: Path = args.scan_root.resolve()
96
  out_root: Path = args.out_root.resolve()
97
  if not scan.is_dir():
98
  raise SystemExit(f"--scan-root not a directory: {scan}")
99
  _validate_roots(scan, out_root)
 
 
 
 
100
 
101
  count = 0
102
  for path in sorted(scan.rglob("*")):
103
  if not path.is_file() or not _is_archive(path):
104
  continue
 
 
 
 
 
105
  mpath = _done_marker_path(path, scan, out_root)
106
  if mpath.exists():
107
  continue
108
  rel_parent = path.parent.relative_to(scan)
109
- dest = out_root / rel_parent / _archive_stem(path)
110
  print(f"[extract] {path.relative_to(scan)} -> {dest.relative_to(out_root)}", flush=True)
111
  _extract_one(path, dest)
 
112
  mpath.parent.mkdir(parents=True, exist_ok=True)
113
  mpath.write_text("ok\n", encoding="utf-8")
114
  count += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  print(f"[extract] done, {count} archives", flush=True)
116
 
117
 
 
7
  解压产物只写在 ``--out-root`` 下;**不在 mirror 里落任何文件**(进度标记也放在
8
  ``out-root/_wjad_extract_state/``,与归档相对路径对应,避免修改 ``--scan-root``)。
9
 
10
+ 在 HF Job 内若设置了 ``WJAD_CACHE_ROOT`` / ``TMPDIR``(见 ``push_to_jobs`` 的 cache 环境),
11
+ 本脚本会在**每解压完一个归档后**清空 ``TMPDIR`` 下的临时内容,降低容器 ephemeral 占用。
12
+
13
+ **分卷归档**(与 Cosmos README 中 ``generation.tar.gz.part-*`` 一致):同目录下若干片文件,
14
+ 按分卷序号排序后 **cat 拼成字节流**,再交给 ``tar`` 从 stdin 解压,**不先合并成大文件**(省磁盘)。
15
+
16
+ 可选 ``--only-stems a,b``(或环境变量 ``WJAD_EXTRACT_ONLY_STEMS``):只解压这些「目标目录名」对应的归档
17
+ (即 ``_archive_stem``:如 ``generation.tar.gz`` / 分卷逻辑名 ``generation.tar.gz`` → stem ``generation``)。
18
+
19
  示例(mount ``hf://buckets/.../WJAD`` → ``/mnt/wjad``)::
20
 
21
  python scripts/jobs_extract_archives.py \\
 
26
  from __future__ import annotations
27
 
28
  import argparse
29
+ import gc
30
+ import os
31
+ import re
32
+ import shutil
33
+ import subprocess
34
  import tarfile
35
+ import tempfile
36
+ import threading
37
  import zipfile
38
+ from collections import defaultdict
39
  from pathlib import Path
40
 
41
+ STATE_DIRNAME = "_wjad_extract_state"
42
+
43
+
44
def _ensure_tmp_on_bucket() -> None:
    """Point Python's temporary directory at the mounted cache volume when configured.

    If ``WJAD_CACHE_ROOT`` is set to a non-blank value, ``<root>/tmp`` is created and
    used as the *default* for ``TMPDIR`` / ``TEMP`` / ``TMP`` (values already present
    in the environment are left untouched).  Whatever ``TMPDIR`` holds afterwards, if
    non-blank, is installed as ``tempfile.tempdir``.
    """
    cache_root = os.environ.get("WJAD_CACHE_ROOT", "").strip()
    if cache_root:
        bucket_tmp = Path(cache_root) / "tmp"
        bucket_tmp.mkdir(parents=True, exist_ok=True)
        # setdefault: never override a temp location the caller chose explicitly.
        for var in ("TMPDIR", "TEMP", "TMP"):
            os.environ.setdefault(var, str(bucket_tmp))
    chosen = os.environ.get("TMPDIR", "").strip()
    if chosen:
        tempfile.tempdir = chosen
56
+
57
+
58
def _clean_tmpdir_after_archive() -> None:
    """Empty the temporary directory after an archive has been extracted (best effort).

    The directory is taken from ``TMPDIR``; if that is blank, ``<WJAD_CACHE_ROOT>/tmp``
    is used instead.  When neither is set, or the path is not a directory, nothing
    happens.  Every direct child is removed (directories recursively); removal errors
    are ignored.  A garbage-collection pass is triggered at the end.
    """
    tmp_root = os.environ.get("TMPDIR", "").strip()
    if not tmp_root:
        cache_root = os.environ.get("WJAD_CACHE_ROOT", "").strip()
        if not cache_root:
            return
        tmp_root = str(Path(cache_root) / "tmp")
    tmp_dir = Path(tmp_root)
    if not tmp_dir.is_dir():
        return
    # Snapshot the listing first: entries are deleted while we walk it.
    for entry in list(tmp_dir.iterdir()):
        try:
            if entry.is_dir():
                shutil.rmtree(entry, ignore_errors=True)
            else:
                entry.unlink(missing_ok=True)
        except OSError:
            # Cleanup is opportunistic; a stubborn entry must not abort extraction.
            pass
    gc.collect()
79
+
80
+
81
# e.g. generation.tar.gz.part-aa -> logical name generation.tar.gz, part key aa
_SPLIT_PART_RE = re.compile(r"(?i)^(?P<base>.+)\.part-(?P<suf>[a-z0-9]+)$")


def _split_part_info(path: Path) -> tuple[str, str] | None:
    """Parse a split-volume file name.

    Returns ``(logical_archive_name, part_suffix)`` when the file name has the form
    ``<base>.part-<suffix>`` (matched case-insensitively, alphanumeric suffix),
    otherwise ``None``.  Only ``path.name`` is inspected.
    """
    hit = _SPLIT_PART_RE.match(path.name)
    return (hit.group("base"), hit.group("suf")) if hit else None
90
+
91
+
92
def _split_part_sort_key(path: Path) -> tuple[int, int | str]:
    """Sort key that orders split volumes by their part suffix.

    Numeric suffixes sort first and by integer value (so ``2`` precedes ``10``),
    other suffixes sort after them as plain strings, and paths that are not split
    volumes at all sort last.
    """
    parsed = _split_part_info(path)
    if not parsed:
        return (2, "")
    _, suffix = parsed
    return (0, int(suffix)) if suffix.isdigit() else (1, suffix)
100
+
101
+
102
def _collect_split_groups(scan: Path) -> dict[tuple[Path, str], list[Path]]:
    """Group split-volume files found under ``scan``.

    Files are grouped by ``(resolved parent directory, logical archive name)`` so that
    all parts of one archive living in the same directory end up in one list.  The
    lists are in discovery order, not part order.
    """
    grouped: dict[tuple[Path, str], list[Path]] = {}
    for candidate in scan.rglob("*"):
        if not candidate.is_file():
            continue
        parsed = _split_part_info(candidate)
        if parsed:
            group_key = (candidate.parent.resolve(), parsed[0])
            grouped.setdefault(group_key, []).append(candidate)
    return grouped
113
+
114
+
115
def _tar_stdin_args(logical_name: str, dest_dir: Path) -> list[str]:
    """Build the ``tar`` command line that extracts an archive read from stdin.

    The decompression flag is chosen from the (case-insensitive) extension of
    ``logical_name``; extraction is directed into ``dest_dir`` via ``-C``.  Relies on
    a ``tar`` executable being available on ``PATH``.

    Raises:
        ValueError: if ``logical_name`` is not a ``.tar`` / ``.tar.*`` / ``.tgz`` name.
    """
    # Order matters only in that the plain ".tar" case is checked last.
    flag_by_suffixes = (
        ((".tar.gz", ".tgz"), "-xzf"),
        ((".tar.bz2",), "-xjf"),
        ((".tar.xz",), "-xJf"),
        ((".tar",), "-xf"),
    )
    lowered = logical_name.lower()
    for suffixes, flag in flag_by_suffixes:
        if lowered.endswith(suffixes):
            return ["tar", flag, "-", "-C", str(dest_dir)]
    raise ValueError(f"不支持的拼接分卷类型(仅 tar / tar.*): {logical_name!r}")
128
+
129
+
130
def _extract_split_volumes(parts: list[Path], logical_name: str, dest_dir: Path) -> None:
    """Stream split volumes, in the given order, into ``tar`` reading from stdin.

    The parts are concatenated on the fly into the pipe; no merged archive is written
    to disk.  ``dest_dir`` is created if needed.

    Args:
        parts: Volume files, already sorted in concatenation order.
        logical_name: Name of the reassembled archive (selects the tar flags).
        dest_dir: Directory to extract into.

    Raises:
        ValueError: if ``parts`` is empty or the archive type is unsupported.
        RuntimeError: if ``tar`` exits with a non-zero status (stderr is included).
    """
    if not parts:
        raise ValueError("分卷列表为空")
    dest_dir.mkdir(parents=True, exist_ok=True)
    tar_args = _tar_stdin_args(logical_name, dest_dir)
    proc = subprocess.Popen(
        tar_args,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    assert proc.stdin is not None
    assert proc.stderr is not None

    stderr_chunks: list[bytes] = []

    def _drain_stderr() -> None:
        # Consume tar's stderr concurrently so tar can never block on a full pipe
        # while we are busy writing to its stdin.
        stderr_chunks.append(proc.stderr.read())

    drainer = threading.Thread(target=_drain_stderr, daemon=True)
    drainer.start()

    # NOTE: the data is fed from this thread on purpose.  Popen.communicate() without
    # ``input`` flushes and *closes* stdin immediately, which would cut off a separate
    # feeder thread that is still writing to the pipe.
    try:
        try:
            for part in parts:
                with part.open("rb") as src:
                    shutil.copyfileobj(src, proc.stdin, length=1024 * 1024 * 8)
        except BrokenPipeError:
            # tar stopped reading (finished early or failed); its exit status below
            # decides whether that is an error.
            pass
        finally:
            try:
                proc.stdin.close()
            except OSError:
                pass
    except BaseException:
        # Reading a part failed (or we were interrupted): do not leave tar running.
        proc.kill()
        proc.wait()
        drainer.join(timeout=30)
        raise

    proc.wait()
    drainer.join(timeout=30)
    proc.stderr.close()
    if proc.returncode != 0:
        msg = b"".join(stderr_chunks).decode(errors="replace")
        raise RuntimeError(f"tar 解压分卷流失败 (exit {proc.returncode}): {msg}")
161
+
162
+
163
def _done_marker_split(scan: Path, out_root: Path, group_parent: Path, logical_name: str) -> Path:
    """Return the done-marker path for a split archive group.

    The marker lives under ``out_root / STATE_DIRNAME``, mirroring the group's
    directory relative to ``scan``, and is named ``<logical_name>.wjad_done``.
    ``group_parent`` must resolve to a location inside ``scan``.
    """
    state_dir = out_root / STATE_DIRNAME / group_parent.resolve().relative_to(scan)
    return state_dir / f"{logical_name}.wjad_done"
166
+
167
 
168
  def _archive_stem(path: Path) -> str:
169
  n = path.name
 
185
  )
186
 
187
 
 
 
 
188
  def _done_marker_path(archive: Path, scan: Path, out_root: Path) -> Path:
189
  """标记只写在 out_root 下,绝不写回 mirror。"""
190
  rel = archive.relative_to(scan)
 
227
  tf.extractall(dest_dir)
228
 
229
 
230
def _only_stems_set(args: argparse.Namespace) -> set[str] | None:
    """Resolve the optional stem filter.

    ``args.only_stems`` takes precedence; when it is missing or blank the
    ``WJAD_EXTRACT_ONLY_STEMS`` environment variable is consulted.  The value is a
    comma-separated list; blank items are dropped.  Returns ``None`` when no filter
    is in effect (nothing given, or only blank items).
    """
    spec = (getattr(args, "only_stems", None) or "").strip()
    if not spec:
        spec = os.environ.get("WJAD_EXTRACT_ONLY_STEMS", "").strip()
    if not spec:
        return None
    wanted = set(filter(None, map(str.strip, spec.split(","))))
    return wanted if wanted else None
238
+
239
+
240
def _stem_ok(only: set[str] | None, stem: str) -> bool:
    """Return True when ``stem`` passes the filter (``None`` means no filtering)."""
    if only is None:
        return True
    return stem in only
242
+
243
+
244
  def main() -> None:
245
  p = argparse.ArgumentParser()
246
  p.add_argument("--scan-root", type=Path, required=True)
247
  p.add_argument("--out-root", type=Path, required=True)
248
+ p.add_argument(
249
+ "--only-stems",
250
+ default=None,
251
+ help="逗号分隔,只解压这些 stem(与解压目标目录名一致,如 generation)",
252
+ )
253
  args = p.parse_args()
254
+ only = _only_stems_set(args)
255
  scan: Path = args.scan_root.resolve()
256
  out_root: Path = args.out_root.resolve()
257
  if not scan.is_dir():
258
  raise SystemExit(f"--scan-root not a directory: {scan}")
259
  _validate_roots(scan, out_root)
260
+ _ensure_tmp_on_bucket()
261
+
262
+ split_groups = _collect_split_groups(scan)
263
+ part_paths: set[Path] = {p for plist in split_groups.values() for p in plist}
264
 
265
  count = 0
266
  for path in sorted(scan.rglob("*")):
267
  if not path.is_file() or not _is_archive(path):
268
  continue
269
+ if path in part_paths:
270
+ continue
271
+ stem = _archive_stem(path)
272
+ if not _stem_ok(only, stem):
273
+ continue
274
  mpath = _done_marker_path(path, scan, out_root)
275
  if mpath.exists():
276
  continue
277
  rel_parent = path.parent.relative_to(scan)
278
+ dest = out_root / rel_parent / stem
279
  print(f"[extract] {path.relative_to(scan)} -> {dest.relative_to(out_root)}", flush=True)
280
  _extract_one(path, dest)
281
+ _clean_tmpdir_after_archive()
282
  mpath.parent.mkdir(parents=True, exist_ok=True)
283
  mpath.write_text("ok\n", encoding="utf-8")
284
  count += 1
285
+
286
+ for (gparent, logical) in sorted(
287
+ split_groups,
288
+ key=lambda k: (str(k[0].resolve().relative_to(scan)), k[1].lower()),
289
+ ):
290
+ stem = _archive_stem(Path(logical))
291
+ if not _stem_ok(only, stem):
292
+ continue
293
+ parts_raw = split_groups[(gparent, logical)]
294
+ parts_sorted = sorted(parts_raw, key=_split_part_sort_key)
295
+ mpath = _done_marker_split(scan, out_root, gparent, logical)
296
+ if mpath.exists():
297
+ continue
298
+ rel_parent = gparent.resolve().relative_to(scan)
299
+ dest = out_root / rel_parent / stem
300
+ print(
301
+ f"[extract-split] {logical} ({len(parts_sorted)} parts) -> {dest.relative_to(out_root)}",
302
+ flush=True,
303
+ )
304
+ _extract_split_volumes(parts_sorted, logical, dest)
305
+ _clean_tmpdir_after_archive()
306
+ mpath.parent.mkdir(parents=True, exist_ok=True)
307
+ mpath.write_text("ok\n", encoding="utf-8")
308
+ count += 1
309
+
310
  print(f"[extract] done, {count} archives", flush=True)
311
 
312
 
scripts/push_to_jobs.py CHANGED
@@ -19,7 +19,8 @@ Bucket 目录约定(相对挂载根)::
19
  - ``inspect-extract-job <job_id>`` — 从 ``hf jobs inspect --json`` 中解析 ``--scan-root``,列出该解压 Job 负责的顶层文件夹名。
20
  - ``extract-parallel`` — 多个并行解压 Job;**合并**为各子目录在 ``extracted/.../<名>/`` 下并列。
21
  ``--shard-dirs a,b,c`` 表示 **3 个 Job**(各处理一个顶层子文件夹)。
22
- ``--shards-file`` 中 **一个 Job**;行内用逗号分隔的多个目录会在 **同一 Job 内顺序解压**(例如 6 行 → 6 个并行 Job)。
 
23
 
24
  单独步骤可加 ``--detach``。Windows 上会使用 ``sys.executable`` 同目录的 ``hf.exe``。
25
  """
@@ -98,14 +99,22 @@ def build_copy_cmd(args: argparse.Namespace) -> list[str]:
98
 
99
 
100
  def _bucket_cache_env_sh(cache_subpath: str) -> str:
101
- """单行:pip / 临时目录 / HF 相关缓存全部落在挂载盘 cache/ 下。"""
 
 
 
102
  c = f"/mnt/wjad/{cache_subpath}"
103
  return (
104
  f"export WJAD_CACHE_ROOT={c} && "
105
- f"mkdir -p {c}/pip {c}/tmp {c}/hf {c}/transformers {c}/torch {c}/datasets {c}/xdg && "
 
 
 
 
106
  f"export PIP_CACHE_DIR={c}/pip && export TMPDIR={c}/tmp && export TEMP={c}/tmp && export TMP={c}/tmp && "
107
  f"export HF_HOME={c}/hf && export TRANSFORMERS_CACHE={c}/transformers && export TORCH_HOME={c}/torch && "
108
- f"export HF_DATASETS_CACHE={c}/datasets && export XDG_CACHE_HOME={c}/xdg"
 
109
  )
110
 
111
 
@@ -126,7 +135,7 @@ def _sanitize_clone_tag_for_group(group: list[str]) -> str:
126
 
127
 
128
  def _safe_rel_subdir_segment(name: str) -> str:
129
- """mirror/extracted 下单层子目录名,禁止路径穿越。"""
130
  n = name.strip()
131
  if not n or n in (".", "..") or "/" in n or "\\" in n or n.startswith("-"):
132
  raise ValueError(f"非法子目录名(禁止路径穿越): {name!r}")
@@ -135,6 +144,20 @@ def _safe_rel_subdir_segment(name: str) -> str:
135
  return n
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def build_wipe_extracted_cmd(args: argparse.Namespace) -> list[str]:
139
  """删掉整个 extracted 根目录(并行分片前或单 Job 全量替换前)。"""
140
  vol = f"hf://buckets/{args.bucket}:/mnt/wjad"
@@ -167,7 +190,7 @@ def build_wipe_extracted_subdirs_cmd(args: argparse.Namespace, subdirs: list[str
167
  ext_base = f"/mnt/wjad/{args.extracted_subpath.rstrip('/')}"
168
  rm_parts = [
169
  f'echo "[wipe-subdir] {ext_base}/{d}" && rm -rf {ext_base}/{d}'
170
- for d in (_safe_rel_subdir_segment(s) for s in subdirs)
171
  ]
172
  inner = "set -e && " + " && ".join(rm_parts)
173
  cmd = [
@@ -229,18 +252,25 @@ def build_extract_cmd(
229
  for _, ed in pairs:
230
  steps.append(f"mkdir -p {ed}")
231
 
 
 
 
232
  extract_chain = " && ".join(
233
- f'echo "[extract] {sd} -> {ed}" && python scripts/jobs_extract_archives.py --scan-root {sd} --out-root {ed}'
234
  for sd, ed in pairs
235
  )
236
 
237
  steps.extend(
238
  [
239
- "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git)",
240
- "pip install -q -U huggingface_hub",
241
- f"rm -rf {clone_dir} && git clone https://oauth2:$HF_TOKEN@huggingface.co/{args.code_repo} {clone_dir}",
 
 
 
242
  f"cd {clone_dir}",
243
  extract_chain,
 
244
  ]
245
  )
246
  inner = " && ".join(steps)
@@ -277,12 +307,15 @@ def build_train_cmd(args: argparse.Namespace) -> list[str]:
277
  f"export WJAD_HUB_REPO={args.weights_repo}",
278
  f"export WJAD_DATA_ROOT=/mnt/wjad/{args.extracted_subpath}",
279
  "export WJAD_OUTPUT_DIR=/mnt/wjad/runs/current",
280
- "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git)",
281
- "pip install -q -U huggingface_hub",
282
- f"rm -rf {clone_dir} && git clone https://oauth2:$HF_TOKEN@huggingface.co/{args.code_repo} {clone_dir}",
 
 
 
283
  f"cd {clone_dir}",
284
- "pip install -q -U pip",
285
- "pip install -q -e .",
286
  "bash scripts/jobs_entry_train.sh",
287
  ]
288
  )
@@ -307,14 +340,33 @@ def build_train_cmd(args: argparse.Namespace) -> list[str]:
307
  return cmd
308
 
309
 
310
- def _load_shard_groups(args: argparse.Namespace) -> list[list[str]]:
311
- """CLI ``--shard-dirs a,b`` → 两个 Job``--shards-file`` 每行一个 Job,行内逗号为同 Job 多目录。"""
312
- groups: list[list[str]] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  if getattr(args, "shard_dirs", None):
314
  for p in args.shard_dirs.split(","):
315
  p = p.strip()
316
  if p:
317
- groups.append([p])
318
  sf = getattr(args, "shards_file", None)
319
  if sf:
320
  raw = Path(sf).expanduser().read_text(encoding="utf-8")
@@ -322,21 +374,22 @@ def _load_shard_groups(args: argparse.Namespace) -> list[list[str]]:
322
  line = line.strip()
323
  if not line or line.startswith("#"):
324
  continue
325
- parts = [x.strip() for x in line.split(",") if x.strip()]
 
326
  if parts:
327
- groups.append(parts)
328
- seen: set[tuple[str, ...]] = set()
329
- uniq: list[list[str]] = []
330
- for g in groups:
331
- t = tuple(g)
332
  if t not in seen:
333
  seen.add(t)
334
- uniq.append(g)
335
  return uniq
336
 
337
 
338
  def _load_replace_subdirs(args: argparse.Namespace) -> list[str]:
339
- """从 CLI / 文件收集要预先删除的 extracted 子目录单层名)。"""
340
  raw: list[str] = []
341
  rs = getattr(args, "replace_subdirs", None)
342
  if rs:
@@ -354,7 +407,7 @@ def _load_replace_subdirs(args: argparse.Namespace) -> list[str]:
354
  seen: set[str] = set()
355
  out: list[str] = []
356
  for s in raw:
357
- seg = _safe_rel_subdir_segment(s)
358
  if seg not in seen:
359
  seen.add(seg)
360
  out.append(seg)
@@ -401,11 +454,12 @@ def run_extract_parallel(args: argparse.Namespace) -> int:
401
  if _wait_job(jid, label="wipe extracted subdirs") != 0:
402
  return 1
403
  jids: list[tuple[str, str]] = []
404
- for grp in groups:
405
  label = ",".join(grp)
406
  sargs = argparse.Namespace(**vars(args))
407
  sargs.detach = True
408
  sargs.replace_extracted = False
 
409
  rc, jid = _submit_detach(build_extract_cmd(sargs, shard_group=grp))
410
  if rc != 0:
411
  return rc
@@ -567,13 +621,13 @@ def main() -> None:
567
  "--replace-subdirs",
568
  default=None,
569
  dest="replace_subdirs",
570
- help="extract-parallel:仅 rm -rf 下列子目录(逗号分隔,相对 extracted-subpath再提交解压",
571
  )
572
  p.add_argument(
573
  "--replace-subdirs-file",
574
  default=None,
575
  dest="replace_subdirs_file",
576
- help="extract-parallel:同上每行一个目录名;# 注释;行内逗号可多个",
577
  )
578
  p.add_argument(
579
  "--shard-dirs",
@@ -585,7 +639,13 @@ def main() -> None:
585
  "--shards-file",
586
  default=None,
587
  dest="shards_file",
588
- help="extract-parallel:每行一个 Job;行内逗号分隔个子目录(同 Job 内顺序解压);# 开头为注释",
 
 
 
 
 
 
589
  )
590
  p.add_argument("--train-image", default=TRAIN_IMAGE, dest="train_image")
591
  p.add_argument("--train-flavor", default=TRAIN_FLAVOR, dest="train_flavor")
 
19
  - ``inspect-extract-job <job_id>`` — 从 ``hf jobs inspect --json`` 中解析 ``--scan-root``,列出该解压 Job 负责的顶层文件夹名。
20
  - ``extract-parallel`` — 多个并行解压 Job;**合并**为各子目录在 ``extracted/.../<名>/`` 下并列。
21
  ``--shard-dirs a,b,c`` 表示 **3 个 Job**(各处理一个顶层子文件夹)。
22
+ ``--shards-file`` 支持每行后缀 ``|only=stem1,stem2``,传给 ``jobs_extract_archives --only-stems``,
23
+ 仅在当次扫描目录内解压这些归档(如仅 ``generation`` 分卷,跳过同目录的 ``caption``/``hdmap``)。
24
 
25
  单独步骤可加 ``--detach``。Windows 上会使用 ``sys.executable`` 同目录的 ``hf.exe``。
26
  """
 
99
 
100
 
101
  def _bucket_cache_env_sh(cache_subpath: str) -> str:
102
+ """单行:pip / 临时目录 / HF 相关缓存全部落在挂载盘 cache/ 下。
103
+
104
+ 将 HOME / PYTHONPYCACHEPREFIX 也指到 Bucket,避免任务写满容器根分区(ephemeral 50G)。
105
+ """
106
  c = f"/mnt/wjad/{cache_subpath}"
107
  return (
108
  f"export WJAD_CACHE_ROOT={c} && "
109
+ f"mkdir -p {c}/pip {c}/tmp {c}/hf {c}/transformers {c}/torch {c}/datasets {c}/xdg "
110
+ f"{c}/jobhome {c}/pycache && "
111
+ f"export HOME={c}/jobhome && "
112
+ f"export PYTHONPYCACHEPREFIX={c}/pycache && "
113
+ f"export PYTHONDONTWRITEBYTECODE=1 && "
114
  f"export PIP_CACHE_DIR={c}/pip && export TMPDIR={c}/tmp && export TEMP={c}/tmp && export TMP={c}/tmp && "
115
  f"export HF_HOME={c}/hf && export TRANSFORMERS_CACHE={c}/transformers && export TORCH_HOME={c}/torch && "
116
+ f"export HF_DATASETS_CACHE={c}/datasets && export XDG_CACHE_HOME={c}/xdg && "
117
+ f"export XDG_CONFIG_HOME={c}/xdg_config && export XDG_DATA_HOME={c}/xdg_data"
118
  )
119
 
120
 
 
135
 
136
 
137
  def _safe_rel_subdir_segment(name: str) -> str:
138
+ """单层子目录名(旧逻辑保留供需单段名的场景)。"""
139
  n = name.strip()
140
  if not n or n in (".", "..") or "/" in n or "\\" in n or n.startswith("-"):
141
  raise ValueError(f"非法子目录名(禁止路径穿越): {name!r}")
 
144
  return n
145
 
146
 
147
def _safe_rel_subpath(rel: str) -> str:
    """Validate and normalise a relative path below the extracted root.

    Multi-level paths such as ``cosmos_synthetic/single_view/generation`` are allowed.
    Surrounding whitespace and leading/trailing slashes are stripped and empty
    segments are collapsed.  The path is rejected when it is empty, starts with
    ``-``, contains ``..`` anywhere, or has a segment that is ``.`` or does not match
    ``[A-Za-z0-9][A-Za-z0-9_.-]*``.

    Returns:
        The cleaned path with segments joined by ``/``.

    Raises:
        ValueError: if the path is rejected.
    """
    cleaned = rel.strip().strip("/")
    if cleaned == "" or cleaned[0] == "-" or cleaned.find("..") != -1:
        raise ValueError(f"非法相对路径(禁止路径穿越): {rel!r}")
    kept: list[str] = []
    for piece in cleaned.split("/"):
        if not piece:
            continue
        if piece == "." or piece == "..":
            raise ValueError(f"非法路径段: {rel!r}")
        if re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_.-]*", piece) is None:
            raise ValueError(f"非法路径段(仅允许安全字符): {rel!r}")
        kept.append(piece)
    return "/".join(kept)
159
+
160
+
161
  def build_wipe_extracted_cmd(args: argparse.Namespace) -> list[str]:
162
  """删掉整个 extracted 根目录(并行分片前或单 Job 全量替换前)。"""
163
  vol = f"hf://buckets/{args.bucket}:/mnt/wjad"
 
190
  ext_base = f"/mnt/wjad/{args.extracted_subpath.rstrip('/')}"
191
  rm_parts = [
192
  f'echo "[wipe-subdir] {ext_base}/{d}" && rm -rf {ext_base}/{d}'
193
+ for d in (_safe_rel_subpath(s) for s in subdirs)
194
  ]
195
  inner = "set -e && " + " && ".join(rm_parts)
196
  cmd = [
 
252
  for _, ed in pairs:
253
  steps.append(f"mkdir -p {ed}")
254
 
255
+ extra = getattr(args, "extract_only_stems", None) or ""
256
+ extra_arg = f" --only-stems {extra}" if extra else ""
257
+
258
  extract_chain = " && ".join(
259
+ f'echo "[extract] {sd} -> {ed}" && python scripts/jobs_extract_archives.py --scan-root {sd} --out-root {ed}{extra_arg}'
260
  for sd, ed in pairs
261
  )
262
 
263
  steps.extend(
264
  [
265
+ "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/*)",
266
+ "pip install --no-cache-dir -q -U huggingface_hub",
267
+ (
268
+ "rm -rf {cd_} && git clone --depth 1 --single-branch "
269
+ "https://oauth2:$HF_TOKEN@huggingface.co/{repo} {cd_}"
270
+ ).format(cd_=clone_dir, repo=args.code_repo),
271
  f"cd {clone_dir}",
272
  extract_chain,
273
+ f"rm -rf {clone_dir}",
274
  ]
275
  )
276
  inner = " && ".join(steps)
 
307
  f"export WJAD_HUB_REPO={args.weights_repo}",
308
  f"export WJAD_DATA_ROOT=/mnt/wjad/{args.extracted_subpath}",
309
  "export WJAD_OUTPUT_DIR=/mnt/wjad/runs/current",
310
+ "command -v git >/dev/null 2>&1 || (apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/*)",
311
+ "pip install --no-cache-dir -q -U huggingface_hub",
312
+ (
313
+ "rm -rf {cd_} && git clone --depth 1 --single-branch "
314
+ "https://oauth2:$HF_TOKEN@huggingface.co/{repo} {cd_}"
315
+ ).format(cd_=clone_dir, repo=args.code_repo),
316
  f"cd {clone_dir}",
317
+ "pip install --no-cache-dir -q -U pip",
318
+ "pip install --no-cache-dir -q -e .",
319
  "bash scripts/jobs_entry_train.sh",
320
  ]
321
  )
 
340
  return cmd
341
 
342
 
343
def _parse_shards_file_line(line: str) -> tuple[str, str | None]:
    """Split one shards-file line into its directory part and optional stem filter.

    ``dir|only=generation,caption`` -> ``("dir", "generation,caption")``.  The text
    before the first ``|`` is returned stripped; it may itself hold several
    comma-separated directories.  After the ``|``, an ``only=`` item starts a stem
    list and the comma-separated items that follow it (up to the next ``key=value``
    item) belong to that list.  A later non-empty ``only=`` replaces an earlier one;
    an empty ``only=`` and unknown ``key=value`` items are ignored.

    Returns:
        ``(directories, stems)`` where ``stems`` is a comma-joined string suitable for
        ``--only-stems``, or ``None`` when no stem filter was given.

    Raises:
        ValueError: if a stem contains characters other than letters, digits,
            ``_`` or ``-``.
    """
    base, sep, flags = line.partition("|")
    if not sep:
        return line.strip(), None

    stems: list[str] = []
    in_only = False
    # The stem list is itself comma-separated, so items after ``only=`` must be
    # treated as continuations instead of being dropped.
    for raw_piece in flags.split(","):
        piece = raw_piece.strip()
        if piece.startswith("only="):
            value = piece[5:].strip()
            if not value:
                continue
            stems = []
            in_only = True
        elif in_only and piece and "=" not in piece:
            value = piece
        else:
            if "=" in piece:
                # Some other key=value flag ends the current stem list.
                in_only = False
            continue
        if not re.fullmatch(r"[A-Za-z0-9_-]+", value):
            raise ValueError(f"only= 仅允许字母数字、逗号、连字符、下划线: {line!r}")
        stems.append(value)
    return base.strip(), (",".join(stems) or None)
360
+
361
+
362
+ def _load_shard_groups(args: argparse.Namespace) -> list[tuple[list[str], str | None]]:
363
+ """``--shard-dirs a,b`` → 两个 Job;文件每行一个 Job;行内逗号多目录;可选 ``|only=…``。"""
364
+ groups: list[tuple[list[str], str | None]] = []
365
  if getattr(args, "shard_dirs", None):
366
  for p in args.shard_dirs.split(","):
367
  p = p.strip()
368
  if p:
369
+ groups.append(([p], None))
370
  sf = getattr(args, "shards_file", None)
371
  if sf:
372
  raw = Path(sf).expanduser().read_text(encoding="utf-8")
 
374
  line = line.strip()
375
  if not line or line.startswith("#"):
376
  continue
377
+ base, only = _parse_shards_file_line(line)
378
+ parts = [x.strip() for x in base.split(",") if x.strip()]
379
  if parts:
380
+ groups.append((parts, only))
381
+ seen: set[tuple[tuple[str, ...], str | None]] = set()
382
+ uniq: list[tuple[list[str], str | None]] = []
383
+ for g, o in groups:
384
+ t = (tuple(g), o)
385
  if t not in seen:
386
  seen.add(t)
387
+ uniq.append((g, o))
388
  return uniq
389
 
390
 
391
  def _load_replace_subdirs(args: argparse.Namespace) -> list[str]:
392
+ """从 CLI / 文件收集要预先删除的 extracted 下相对路径可多级)。"""
393
  raw: list[str] = []
394
  rs = getattr(args, "replace_subdirs", None)
395
  if rs:
 
407
  seen: set[str] = set()
408
  out: list[str] = []
409
  for s in raw:
410
+ seg = _safe_rel_subpath(s)
411
  if seg not in seen:
412
  seen.add(seg)
413
  out.append(seg)
 
454
  if _wait_job(jid, label="wipe extracted subdirs") != 0:
455
  return 1
456
  jids: list[tuple[str, str]] = []
457
+ for grp, only_opt in groups:
458
  label = ",".join(grp)
459
  sargs = argparse.Namespace(**vars(args))
460
  sargs.detach = True
461
  sargs.replace_extracted = False
462
+ sargs.extract_only_stems = only_opt
463
  rc, jid = _submit_detach(build_extract_cmd(sargs, shard_group=grp))
464
  if rc != 0:
465
  return rc
 
621
  "--replace-subdirs",
622
  default=None,
623
  dest="replace_subdirs",
624
+ help="extract-parallel:仅 rm -rf 下列路径(逗号分隔,相对 extracted-subpath,可多级如 a/b/c)",
625
  )
626
  p.add_argument(
627
  "--replace-subdirs-file",
628
  default=None,
629
  dest="replace_subdirs_file",
630
+ help="extract-parallel:同上每行相对 extracted-subpath 的条路径(可多级);# 注释",
631
  )
632
  p.add_argument(
633
  "--shard-dirs",
 
639
  "--shards-file",
640
  default=None,
641
  dest="shards_file",
642
+ help="extract-parallel:每行一个 Job;行内逗号多目录;可选后缀 |only=stem1,stem2(仅解压这些归档);# 注释",
643
+ )
644
+ p.add_argument(
645
+ "--extract-only-stems",
646
+ default=None,
647
+ dest="extract_only_stems",
648
+ help="仅 action=extract:传给 jobs_extract_archives --only-stems;parallel 请用 shards 行内 |only=",
649
  )
650
  p.add_argument("--train-image", default=TRAIN_IMAGE, dest="train_image")
651
  p.add_argument("--train-flavor", default=TRAIN_FLAVOR, dest="train_flavor")