File size: 10,888 Bytes
08ff31f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#!/usr/bin/env python3
"""Run a sweep of variable-speed ablation experiments end-to-end.

For each ablation: build dataset (multi-process) -> compute norm stats -> train.
The norm-stats and train stages share one TrainConfig name; per-ablation
differences are applied to training via CLI overrides on ``--data.repo-id``,
``--data.assets.asset-id`` and ``--data.speed-integration`` (plus any
per-ablation extra args), and via ``--repo-id`` / ``--asset-id`` for
compute_norm_stats.

Edit the ABLATIONS table below to define your groups. Use:
- ``--only NAME[,NAME]``         to run a subset
- ``--skip-build / --skip-norm-stats / --skip-train`` to bypass stages
- ``--dry-run``                  to print commands without executing them

Example:

    uv run python scripts/run_ablations.py \\
      --src $SRC --out-root $ROOT \\
      --train-config pi05_libero_various_speed_all \\
      --base-asset-id libero_various_speed_all_pi05 \\
      --exp-prefix pi05_ablation \\
      --num-train-steps 30000 \\
      --build-num-workers 16 --train-num-workers 8 \\
      --only g2_coarse,g3a_step025 --dry-run
"""
from __future__ import annotations

import argparse
import dataclasses
import shlex
import subprocess
import sys
from pathlib import Path


@dataclasses.dataclass(frozen=True)
class Ablation:
    name: str
    speeds: tuple[float, ...]
    # How speed is fed to the model: "text" | "modulation" | "soft_prompt" | "auto".
    # See LeRobotVariousSpeedLiberoDataConfig.speed_integration for semantics.
    speed_integration: str = "auto"
    # Extra args appended verbatim to the train_pytorch.py invocation. Use this
    # for per-ablation model overrides (e.g., speed_modulation for modulation).
    extra_train_args: tuple[str, ...] = ()
    # When set, multiple ablations sharing the same key share one norm_stats
    # file (and thus one ``asset_id``). Norm stats only depend on the dataset's
    # state/actions, so a P-length sweep over the same speed set should reuse
    # one set of stats. Default ``None`` -> isolate per ablation (legacy).
    shared_norm_key: str | None = None


# Speeds shared by the speed-integration and soft-prompt-P sweeps (all train
# on the same dataset; only the model conditioning differs).
_SPEED_INT_SPEEDS: tuple[float, ...] = (0.75, 1.0, 1.25, 1.5)
_SOFT_PROMPT_P_VALUES: tuple[int, ...] = (1, 4, 8, 16, 32)


def _soft_prompt_ablation(p: int) -> Ablation:
    """Return the soft-prompt ablation with prompt length ``p``.

    The ablation is named ``softprompt_p{p}``, trains on ``_SPEED_INT_SPEEDS``
    with ``speed_integration="soft_prompt"``, and uses the shared norm key
    ``"softprompt_shared"``. Norm stats depend only on the dataset, so every
    P-variant reuses one norm_stats file.
    """
    # tyro expects tuple[float, ...] flags as separate argv elements after the
    # flag, NOT as a comma-joined string -- same convention as --eval-speed-set.
    speed_values = [f"{s:g}" for s in _SPEED_INT_SPEEDS]
    model_overrides = (
        f"--model.soft-prompt-p={p}",
        "--model.soft-prompt-speeds",
        *speed_values,
    )
    return Ablation(
        name=f"softprompt_p{p}",
        speeds=_SPEED_INT_SPEEDS,
        speed_integration="soft_prompt",
        extra_train_args=model_overrides,
        shared_norm_key="softprompt_shared",
    )


# The sweep definition -- edit this table to define your ablation groups.
# Each entry's ``name`` becomes the exp_name suffix and, unless
# ``shared_norm_key`` is set, the asset_id suffix, so keep names short and
# unique. Dataset directories are keyed on the speed set (not the name):
# ablations with identical ``speeds`` share one built dataset, and ablations
# with the same ``shared_norm_key`` share one norm_stats asset.
ABLATIONS: tuple[Ablation, ...] = (
    # --- Speed-set sweep (textual prompt): effect of range and step size. ---
    Ablation(name="g1_baseline", speeds=(1.0,), speed_integration="text"),
    Ablation(name="g2_coarse", speeds=(0.5, 1.0, 1.5, 2.0), speed_integration="text"),
    Ablation(
        name="g3a_step025",
        speeds=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0),
        speed_integration="text",
    ),
    # NOTE(review): g4_narrow is configured identically to speedint_text below
    # (same speeds, "text", no extra args); only name/asset_id/exp_name differ.
    # If both are selected, the same setup is trained twice and norm stats are
    # computed twice on one dataset. Presumably intentional so each sweep is
    # self-contained -- confirm.
    Ablation(name="g4_narrow", speeds=(0.75, 1.0, 1.25, 1.5), speed_integration="text"),
    Ablation(name="g5_extreme", speeds=(0.25, 0.5, 1.0, 2.0, 4.0), speed_integration="text"),
    # --- Speed-integration sweep: one dataset, different conditioning. ---
    # The soft_prompt arm of this sweep is softprompt_p8 from the P sweep below.
    Ablation(name="speedint_text", speeds=_SPEED_INT_SPEEDS, speed_integration="text"),
    Ablation(
        name="speedint_modulation",
        speeds=_SPEED_INT_SPEEDS,
        speed_integration="modulation",
        extra_train_args=("--model.speed-modulation=True",),
    ),
    # --- Soft-prompt P-length sweep: one entry per value in
    # _SOFT_PROMPT_P_VALUES, all sharing one dataset and the
    # "softprompt_shared" norm_stats asset. ---
    *map(_soft_prompt_ablation, _SOFT_PROMPT_P_VALUES),
)


def _speed_token(speeds: tuple[float, ...]) -> str:
    return "_".join(f"{s:g}".replace(".", "p") for s in speeds)


def _dataset_dir(out_root: Path, ablation: Ablation, run_tag: str) -> Path:
    """Return the directory under ``out_root`` where ``ablation``'s dataset lives.

    Only the speed set and ``run_tag`` determine the path (the ablation name
    does not), so ablations with identical speeds share one directory.
    """
    token = _speed_token(ablation.speeds)
    dirname = f"libero_speed_{token}_{run_tag}"
    return out_root.joinpath(dirname)


def _asset_id(base_asset_id: str, ablation: Ablation) -> str:
    """Return the norm-stats asset id: ``<base>_<shared_norm_key or name>``.

    An unset (or empty) ``shared_norm_key`` falls back to the ablation name,
    giving the ablation its own asset.
    """
    key = ablation.shared_norm_key
    if not key:
        key = ablation.name
    return "{}_{}".format(base_asset_id, key)


def _exp_name(prefix: str, ablation: Ablation) -> str:
    """Return the training experiment name: ``<prefix>_<ablation name>``."""
    return "{}_{}".format(prefix, ablation.name)


def _run(cmd: list[str], dry_run: bool) -> int:
    print(f"$ {' '.join(shlex.quote(c) for c in cmd)}")
    if dry_run:
        return 0
    return subprocess.call(cmd)


def build_cmd(
    ablation: Ablation,
    src: Path,
    out_root: Path,
    run_tag: str,
    num_workers: int,
    clean_eps: float,
) -> tuple[Path, list[str]]:
    """Return ``(dataset_dir, argv)`` for building ``ablation``'s dataset.

    The argv invokes ``scripts/build_libero_speed_dataset_mp.py`` reading from
    ``src`` and writing to ``dataset_dir``. It passes ``clean_eps`` as both
    the translation and rotation cleaning epsilon, a minimum segment length
    of 1, and ``--overwrite``.
    """
    out_dir = _dataset_dir(out_root, ablation, run_tag)
    eps = str(clean_eps)
    speed_values = [f"{s:g}" for s in ablation.speeds]

    argv = ["uv", "run", "python", "scripts/build_libero_speed_dataset_mp.py"]
    argv += ["--src", str(src), "--dst", str(out_dir)]
    argv += ["--speeds", *speed_values]
    argv += ["--clean-transl-eps", eps, "--clean-rot-eps", eps]
    argv += ["--min-segment-len", "1"]
    argv += ["--num-workers", str(num_workers)]
    argv.append("--overwrite")
    return out_dir, argv


def norm_stats_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
) -> list[str]:
    """Return the argv that runs ``scripts/compute_norm_stats.py`` for ``ablation``.

    The train config is overridden to read ``dataset_dir`` and to write stats
    under the ablation's asset id (see ``_asset_id``).
    """
    asset_id = _asset_id(base_asset_id, ablation)
    launcher = ["uv", "run", "python", "scripts/compute_norm_stats.py"]
    overrides = ["--repo-id", str(dataset_dir), "--asset-id", asset_id]
    return [*launcher, train_config_name, *overrides]


def train_cmd(
    ablation: Ablation,
    dataset_dir: Path,
    train_config_name: str,
    base_asset_id: str,
    exp_prefix: str,
    num_workers: int,
    num_train_steps: int,
) -> list[str]:
    """Return the argv that runs ``scripts/train_pytorch.py`` for ``ablation``.

    Common overrides come first: exp name, dataset, asset id, speed
    integration, workers, steps, ``--overwrite``. Next comes the eval speed
    set, with each speed as its own argv element after ``--eval-speed-set``
    (tuple-valued flags take space-separated values, not a comma-joined
    string). The ablation's ``extra_train_args`` are appended verbatim last.
    """
    exp_name = _exp_name(exp_prefix, ablation)
    asset_id = _asset_id(base_asset_id, ablation)
    overrides = [
        f"--exp-name={exp_name}",
        f"--data.repo-id={dataset_dir}",
        f"--data.assets.asset-id={asset_id}",
        f"--data.speed-integration={ablation.speed_integration}",
        f"--num-workers={num_workers}",
        f"--num-train-steps={num_train_steps}",
        "--overwrite",
    ]
    eval_speeds = [f"{s:g}" for s in ablation.speeds]
    return [
        "uv", "run", "python", "scripts/train_pytorch.py", train_config_name,
        *overrides,
        "--eval-speed-set", *eval_speeds,
        *ablation.extra_train_args,
    ]


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--src", required=True, help="Source LeRobot dataset root")
    p.add_argument("--out-root", required=True, help="Root directory under which built datasets land")
    p.add_argument("--train-config", default="pi05_libero_various_speed_all")
    p.add_argument("--base-asset-id", default="libero_various_speed_all_pi05")
    p.add_argument("--exp-prefix", default="ablation")
    p.add_argument("--run-tag", default="ablation")
    p.add_argument(
        "--only",
        default=None,
        help="Comma-separated list of ablation names. Default = all in ABLATIONS.",
    )
    p.add_argument("--skip-build", action="store_true")
    p.add_argument("--skip-norm-stats", action="store_true")
    p.add_argument("--skip-train", action="store_true")
    p.add_argument("--build-num-workers", type=int, default=16)
    p.add_argument("--train-num-workers", type=int, default=8)
    p.add_argument("--num-train-steps", type=int, default=30_000)
    # Default 0.0 disables near-zero action cleaning (LIBERO is clean enough).
    p.add_argument("--clean-eps", type=float, default=0.0)
    p.add_argument("--dry-run", action="store_true")
    return p.parse_args()


def _validate_ablations(ablations: "tuple[Ablation, ...]", base_asset_id: str) -> None:
    """Exit with a message if the ablation table breaks an invariant the sweep relies on.

    * Names must be unique: they become exp_name suffixes, and training runs
      with ``--overwrite``, so two ablations with one name would share an
      exp_name.
    * Ablations resolving to the same asset_id (via ``shared_norm_key``) must
      use the same speed set. Norm stats are computed once per asset_id, from
      whichever ablation reaches it first, so a mismatch would silently train
      the others with a different dataset's statistics.
    """
    seen_names: set[str] = set()
    first_by_asset: dict[str, Ablation] = {}
    for ab in ablations:
        if ab.name in seen_names:
            sys.exit(f"Duplicate ablation name in ABLATIONS: {ab.name!r}")
        seen_names.add(ab.name)

        asset_id = _asset_id(base_asset_id, ab)
        first = first_by_asset.setdefault(asset_id, ab)
        # Compare speed tokens: they are what selects the dataset directory.
        if _speed_token(first.speeds) != _speed_token(ab.speeds):
            sys.exit(
                f"Ablations {first.name!r} and {ab.name!r} share asset_id {asset_id!r} "
                f"but use different speeds ({first.speeds} vs {ab.speeds}); "
                "their norm stats would not match one of the datasets."
            )


def main() -> None:
    """Run build -> norm-stats -> train for each selected ablation, in ABLATIONS order.

    Within one invocation each distinct speed set is built at most once and
    each asset_id gets norm stats computed at most once. The first failing
    subprocess aborts the sweep via ``sys.exit`` with a message.
    """
    args = parse_args()
    src = Path(args.src).resolve()
    out_root = Path(args.out_root).resolve()

    known_names = [a.name for a in ABLATIONS]
    only: set[str] | None = None
    if args.only:
        # Drop blank entries so a stray/trailing comma ("a,b,") is harmless.
        only = {name.strip() for name in args.only.split(",") if name.strip()}
        if not only:
            sys.exit(f"--only contained no ablation names. Choose from: {known_names}")
        unknown = only.difference(known_names)
        if unknown:
            sys.exit(f"Unknown ablation names: {sorted(unknown)}")
    selected = [a for a in ABLATIONS if only is None or a.name in only]
    if not selected:
        sys.exit(f"No ablations selected. Names in --only must be one of: {known_names}")

    # Validate the whole table, not just the selection: norm stats may be
    # reused across invocations (e.g. with --skip-norm-stats).
    _validate_ablations(ABLATIONS, args.base_asset_id)

    print(f"selected   = {[a.name for a in selected]}")
    print(f"src        = {src}")
    print(f"out_root   = {out_root}")
    print(f"run_tag    = {args.run_tag}")
    print(f"train_cfg  = {args.train_config}")
    print(f"asset_base = {args.base_asset_id}")
    print(f"exp_prefix = {args.exp_prefix}")
    print()

    built_speeds: set[str] = set()
    norm_done_asset_ids: set[str] = set()
    for ab in selected:
        print(f"\n========== ablation: {ab.name}  speeds={ab.speeds}  speed_integration={ab.speed_integration} ==========")
        ds_dir = _dataset_dir(out_root, ab, args.run_tag)
        speed_tok = _speed_token(ab.speeds)

        if not args.skip_build:
            # Ablations with identical speeds map to the same dataset dir.
            if speed_tok in built_speeds:
                print(f"[skip build, dataset already built in this run] dataset_dir = {ds_dir}")
            else:
                _, cmd = build_cmd(ab, src, out_root, args.run_tag, args.build_num_workers, args.clean_eps)
                rc = _run(cmd, args.dry_run)
                if rc != 0:
                    sys.exit(f"build failed for {ab.name} (rc={rc})")
                built_speeds.add(speed_tok)
        else:
            print(f"[skip build] dataset_dir = {ds_dir}")

        if not args.skip_norm_stats:
            asset_id = _asset_id(args.base_asset_id, ab)
            # Shared asset ids (shared_norm_key) need their stats only once;
            # _validate_ablations guarantees they all point at one dataset.
            if asset_id in norm_done_asset_ids:
                print(f"[skip norm-stats, already computed in this run for asset_id={asset_id}]")
            else:
                cmd = norm_stats_cmd(ab, ds_dir, args.train_config, args.base_asset_id)
                rc = _run(cmd, args.dry_run)
                if rc != 0:
                    sys.exit(f"norm-stats failed for {ab.name} (rc={rc})")
                norm_done_asset_ids.add(asset_id)

        if not args.skip_train:
            cmd = train_cmd(
                ab,
                ds_dir,
                args.train_config,
                args.base_asset_id,
                args.exp_prefix,
                args.train_num_workers,
                args.num_train_steps,
            )
            rc = _run(cmd, args.dry_run)
            if rc != 0:
                sys.exit(f"train failed for {ab.name} (rc={rc})")


if __name__ == "__main__":
    # main() returns None on success (exit status 0); failures exit earlier
    # through sys.exit(<message>) inside main().
    raise SystemExit(main())