File size: 1,759 Bytes
54fa103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""DVC training wrapper — reads params.yaml and invokes train.py.

This script exists because DVC cmd templates don't support conditional
expressions (e.g. ${augment == true && "--augment" || ""}).  The wrapper
reads params.yaml, builds the correct argv, runs train.py as a child process,
and exits with that process's return code.

Called by dvc.yaml stages.train.  Do not call directly.
"""

import shlex
import subprocess
import sys
from pathlib import Path

import yaml

# Repository root: the parent of the directory containing this script.
# ``resolve()`` makes it absolute, so the paths below do not depend on the
# current working directory or on how the script was launched.  Without it, a
# relative ``__file__`` such as "dvc_train.py" makes ``.parent.parent``
# collapse to ".".
REPO_ROOT = Path(__file__).resolve().parent.parent
PARAMS_FILE = REPO_ROOT / "params.yaml"
TRAIN_SCRIPT = REPO_ROOT / "train.py"


def main() -> None:
    with open(PARAMS_FILE) as f:
        p = yaml.safe_load(f)["train"]

    cmd = [sys.executable, str(TRAIN_SCRIPT)]

    # Required string / numeric flags
    cmd += ["--benchmark",   str(p["benchmark"])]
    cmd += ["--model",       str(p["model"])]
    cmd += ["--loss",        str(p["loss"])]
    cmd += ["--h1_alpha",    str(p["h1_alpha"])]
    cmd += ["--modes",       str(p["n_modes"])]
    cmd += ["--hidden",      str(p["hidden_dim"])]
    cmd += ["--layers",      str(p["n_layers"])]
    cmd += ["--batch_size",  str(p["batch_size"])]
    cmd += ["--lr",          str(p["lr"])]
    cmd += ["--grad_clip",   str(p["grad_clip"])]
    cmd += ["--budget",      str(p["budget_s"])]

    # Optional boolean flags (action="store_true" in train.py)
    if p.get("augment"):
        cmd.append("--augment")
    if p.get("curriculum"):
        cmd.append("--curriculum")
    if p.get("save_ckpt"):
        cmd.append("--save_ckpt")

    # Optional name
    name = str(p.get("name", "")).strip()
    if name:
        cmd += ["--name", name]

    print(f"[dvc_train] Running: {' '.join(cmd)}", flush=True)
    result = subprocess.run(cmd)
    sys.exit(result.returncode)


if __name__ == "__main__":
    main()