File size: 10,733 Bytes
d1bab46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d95a287
 
 
ac0940b
d1bab46
 
d95a287
d1bab46
d95a287
 
 
 
 
 
 
 
d1bab46
 
 
724c9e9
 
 
d1bab46
 
 
 
724c9e9
 
 
d1bab46
 
 
12e9bac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac0940b
 
 
 
 
d95a287
ac0940b
 
 
 
d95a287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724c9e9
 
 
 
 
 
 
d95a287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724c9e9
 
d95a287
 
 
 
 
 
 
 
 
 
 
 
 
 
ac0940b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12e9bac
 
 
 
 
 
 
 
 
 
 
 
 
 
ac0940b
 
 
 
 
 
 
 
 
 
 
d1bab46
 
 
 
 
 
 
 
 
 
 
ac0940b
d95a287
 
 
 
 
 
 
 
d1bab46
 
 
 
ac0940b
d1bab46
 
 
 
 
 
ac0940b
d1bab46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724c9e9
 
d1bab46
 
724c9e9
d1bab46
 
 
724c9e9
d1bab46
 
 
724c9e9
d1bab46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724c9e9
d1bab46
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
#!/usr/bin/env python3
"""Dispatch CheXVision training kernels to Kaggle and manage their lifecycle.

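Usage:
    python scripts/dispatch.py kaggle scratch          # push & run the scratch kernel
    python scripts/dispatch.py kaggle transfer         # push & run the transfer kernel
    python scripts/dispatch.py kaggle resize_320       # push & run the CPU-only resize kernel
    python scripts/dispatch.py kaggle push scratch     # explicit form of the shorthand above
    python scripts/dispatch.py kaggle status scratch   # check kernel status
    python scripts/dispatch.py kaggle output scratch   # download kernel output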

Requires the Kaggle CLI: pip install kaggle
"""

from __future__ import annotations

import argparse
import base64
import io
import json
import os
import re
import shlex
import subprocess
import sys
import zipfile
from pathlib import Path
from shutil import rmtree, which

# Repository root: this script lives in <root>/scripts/.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Scratch area where rendered Kaggle bundles are written before pushing.
BUNDLE_ROOT = PROJECT_ROOT.joinpath(".codex_tmp", "kaggle")
# Project directories (relative to the root) embedded into the bundle payload.
BUNDLE_PATHS = tuple(map(Path, ("src", "configs")))
# Placeholder in a kernel's script.py that is replaced with the base64 zip.
BUNDLE_SENTINEL = "__CHEXVISION_PROJECT_BUNDLE_B64__"
# Directory names and file suffixes never copied into the bundle.
EXCLUDED_BUNDLE_DIRS = {".pytest_cache", "__pycache__"}
EXCLUDED_BUNDLE_SUFFIXES = {".pyo", ".pyc"}

# Short model name -> kernel source directory (relative to the repo root).
KERNEL_DIRS = dict(
    scratch=Path("kaggle/train_scratch"),
    transfer=Path("kaggle/train_transfer"),
    resize_320=Path("kaggle/resize_320"),
)

# Short model name -> Kaggle kernel slug ("<owner>/<kernel>"); this is the
# value written to the "id" field of the rendered kernel-metadata.json.
KERNEL_SLUGS = dict(
    scratch="hlexnc/chexvision-train-scratch-cnn",
    transfer="hlexnc/chexvision-train-densenet-transfer",
    resize_320="hlexnc/chexvision-resize-320",
)


def _get_kaggle_version() -> tuple[int, ...] | None:
    """Return the installed Kaggle CLI version as a tuple when available."""
    if which("kaggle") is None:
        return None

    result = subprocess.run(
        ["kaggle", "--version"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return None

    output = (result.stdout or result.stderr).strip()
    prefix = "Kaggle API "
    if output.startswith(prefix):
        output = output[len(prefix):]

    try:
        return tuple(int(part) for part in output.split("."))
    except ValueError:
        return None


def _load_env() -> None:
    """Apply the repository's ``.env`` file via python-dotenv, if installed.

    python-dotenv is an optional dependency: if it cannot be imported, this
    function does nothing.
    """
    env_file = PROJECT_ROOT / ".env"
    try:
        from dotenv import load_dotenv as _apply_dotenv

        _apply_dotenv(env_file)
    except ImportError:
        pass


def _build_source_archive() -> bytes:
    """Zip every bundled project directory into an in-memory archive.

    Files are added in sorted order so the archive layout is stable between
    runs. The exclusion filter is applied to the path *relative to the project
    root*, so directory names above the checkout (e.g. a parent named
    ``foo.egg-info``) cannot accidentally exclude everything.

    Exits with status 1 if one of ``BUNDLE_PATHS`` is missing, because the
    remote script would otherwise receive an empty source payload.
    """
    archive_buffer = io.BytesIO()
    with zipfile.ZipFile(archive_buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for rel_root in BUNDLE_PATHS:
            source_root = PROJECT_ROOT / rel_root
            if not source_root.is_dir():
                print(
                    f"ERROR: Bundle source directory not found: {source_root}",
                    file=sys.stderr,
                )
                sys.exit(1)
            for path in sorted(source_root.rglob("*")):
                rel_path = path.relative_to(PROJECT_ROOT)
                if path.is_file() and _should_bundle_path(rel_path):
                    archive.write(path, arcname=rel_path.as_posix())
    return archive_buffer.getvalue()


def _build_kaggle_bundle(model: str) -> Path:
    """Create a self-contained bundle for Kaggle to run remotely.

    Kaggle script pushes only keep the main code file. If the kernel's
    ``script.py`` contains ``BUNDLE_SENTINEL``, the project source is zipped,
    base64-encoded, and substituted for the first occurrence of the sentinel.
    Otherwise the script is copied unchanged.

    Args:
        model: Short kernel name (a key of ``KERNEL_DIRS``/``KERNEL_SLUGS``).

    Returns:
        The freshly (re)created bundle directory containing ``script.py`` and
        ``kernel-metadata.json``.
    """
    kernel_dir = PROJECT_ROOT / KERNEL_DIRS[model]
    if not kernel_dir.exists():
        print(f"ERROR: Kernel directory not found: {kernel_dir}", file=sys.stderr)
        sys.exit(1)

    script_path = kernel_dir / "script.py"
    if not script_path.is_file():
        print(f"ERROR: Kernel script not found: {script_path}", file=sys.stderr)
        sys.exit(1)
    # Read the template before wiping the previous bundle.
    script_template = script_path.read_text(encoding="utf-8")

    bundle_dir = BUNDLE_ROOT / model
    if bundle_dir.exists():
        rmtree(bundle_dir)
    bundle_dir.mkdir(parents=True, exist_ok=True)

    if BUNDLE_SENTINEL in script_template:
        # Only build the (potentially large) archive when it will be embedded.
        bundle_b64 = base64.b64encode(_build_source_archive()).decode("ascii")
        rendered_script = script_template.replace(BUNDLE_SENTINEL, bundle_b64, 1)
    else:
        # Self-contained script (e.g. resize_320) — no bundle injection needed.
        rendered_script = script_template
    (bundle_dir / "script.py").write_text(rendered_script, encoding="utf-8")

    metadata = _render_kernel_metadata(model)
    (bundle_dir / "kernel-metadata.json").write_text(
        json.dumps(metadata, indent=2) + "\n",
        encoding="utf-8",
    )
    return bundle_dir


def _render_kernel_metadata(model: str) -> dict[str, object]:
    """Load a kernel's kernel-metadata.json and apply the dispatch overrides.

    The kernel id, code file, language, kernel type, and internet access are
    always overwritten. Every other field from the source file is kept as is,
    notably ``enable_gpu``: training kernels set it true and CPU-only kernels
    (e.g. resize_320) set it false.
    """
    source = PROJECT_ROOT / KERNEL_DIRS[model] / "kernel-metadata.json"
    rendered = json.loads(source.read_text(encoding="utf-8"))
    overrides = (
        ("id", KERNEL_SLUGS[model]),
        ("code_file", "script.py"),
        ("language", "python"),
        ("kernel_type", "script"),
        ("enable_internet", True),
    )
    for key, value in overrides:
        rendered[key] = value
    return rendered


def _should_bundle_path(path: Path) -> bool:
    """Decide whether *path* belongs in the Kaggle source bundle.

    Rejects compiled-bytecode files (suffix in ``EXCLUDED_BUNDLE_SUFFIXES``).
    Also rejects any path with a component that is in ``EXCLUDED_BUNDLE_DIRS``
    or ends with ``.egg-info``.
    """
    if path.suffix in EXCLUDED_BUNDLE_SUFFIXES:
        return False
    return not any(
        component in EXCLUDED_BUNDLE_DIRS or component.endswith(".egg-info")
        for component in path.parts
    )


def _ensure_kaggle_auth(model: str) -> None:
    """Map repo-local Kaggle credentials into the variables the CLI expects."""
    _load_env()

    if os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY"):
        return

    api_token = os.environ.get("KAGGLE_API_TOKEN", "").strip()
    if not api_token:
        print(
            "ERROR: Kaggle credentials not found. Set KAGGLE_USERNAME/KAGGLE_KEY "
            "or provide KAGGLE_API_TOKEN in .env.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Newer Kaggle personal access tokens look like KGAT_... and are handled
    # directly by newer Kaggle CLI releases without a username split.
    if api_token.startswith("KGAT_"):
        version = _get_kaggle_version()
        if version is not None and version < (1, 8, 0):
            print(
                "ERROR: Detected a newer Kaggle API token (KGAT_...), but the "
                f"installed Kaggle CLI is {'.'.join(map(str, version))}. "
                "Upgrade Kaggle CLI to >= 1.8.0 or use kagglehub >= 0.4.1.",
                file=sys.stderr,
            )
            sys.exit(1)
        return

    if ":" in api_token:
        username, key = api_token.split(":", 1)
        os.environ.setdefault("KAGGLE_USERNAME", username)
        os.environ.setdefault("KAGGLE_KEY", key)
        return

    owner = KERNEL_SLUGS[model].split("/", 1)[0]
    os.environ.setdefault("KAGGLE_USERNAME", owner)
    os.environ.setdefault("KAGGLE_KEY", api_token)


def _run(cmd: list[str]) -> None:
    """Run a subprocess and stream its output."""
    print(f"$ {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=False)
    if result.returncode != 0:
        print(f"Command exited with code {result.returncode}", file=sys.stderr)
        sys.exit(result.returncode)


def cmd_push(model: str) -> None:
    """Push a kernel to Kaggle, which triggers a new remote run.

    Ensures credentials are available, renders the bundle directory, prints
    a reminder about remote secrets, then runs ``kaggle kernels push``.

    Args:
        model: Short kernel name (a key of ``KERNEL_SLUGS``).
    """
    _ensure_kaggle_auth(model)
    bundle_dir = _build_kaggle_bundle(model)
    # The rendered metadata forces enable_internet; enable_gpu is taken
    # verbatim from each kernel's kernel-metadata.json, so say exactly that.
    print(
        "NOTE: Kaggle runs remotely and will not inherit local .env values. "
        "Add HF_TOKEN in Kaggle Secrets for authenticated HF dataset access "
        "and automatic model uploads. Dispatch bundles always enable Kaggle "
        "internet; GPU follows each kernel's kernel-metadata.json (enable_gpu)."
    )
    _run(["kaggle", "kernels", "push", "-p", str(bundle_dir)])


def cmd_status(model: str) -> None:
    """Print the run status Kaggle reports for *model*'s kernel.

    Args:
        model: Short kernel name (a key of ``KERNEL_SLUGS``).
    """
    _ensure_kaggle_auth(model)
    _run(["kaggle", "kernels", "status", KERNEL_SLUGS[model]])


def cmd_output(model: str) -> None:
    """Fetch a finished kernel's output files into ``kaggle_output/<model>``.

    The destination is relative to the current working directory and is
    created if needed. Its absolute path is printed after the download.

    Args:
        model: Short kernel name (a key of ``KERNEL_SLUGS``).
    """
    _ensure_kaggle_auth(model)
    kernel_slug = KERNEL_SLUGS[model]
    destination = Path("kaggle_output") / model
    destination.mkdir(parents=True, exist_ok=True)
    _run(["kaggle", "kernels", "output", kernel_slug, "-p", str(destination)])
    print(f"Output saved to {destination.resolve()}")


def main(argv: list[str] | None = None) -> None:
    """Parse the command line and run the requested Kaggle action.

    The ``kaggle <model>`` shorthand is *not* handled here; ``_preprocess_argv``
    rewrites it to ``kaggle push <model>`` before this runs.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(
        description="Dispatch CheXVision training to Kaggle."
    )
    subparsers = parser.add_subparsers(dest="platform", help="Target platform")

    # --- kaggle sub-command ---------------------------------------------------
    kaggle_parser = subparsers.add_parser("kaggle", help="Kaggle kernel operations")
    kaggle_sub = kaggle_parser.add_subparsers(dest="action", help="Action to perform")

    # Derive choices from KERNEL_DIRS so a new kernel is registered in one place.
    kernel_names = list(KERNEL_DIRS)

    # action -> (help text, handler); each handler takes the model name.
    actions = {
        "push": ("Push kernel to Kaggle", cmd_push),
        "status": ("Check kernel status", cmd_status),
        "output": ("Download kernel output", cmd_output),
    }
    for action, (help_text, _handler) in actions.items():
        action_parser = kaggle_sub.add_parser(action, help=help_text)
        action_parser.add_argument("model", choices=kernel_names)

    args = parser.parse_args(argv)

    if args.platform is None:
        parser.print_help()
        sys.exit(1)

    if args.platform == "kaggle":
        if args.action is None:
            # `dispatch.py kaggle` with no action or model: show kaggle usage.
            kaggle_parser.print_help()
            sys.exit(1)
        # argparse restricts `action` to the keys registered above.
        _help_text, handler = actions[args.action]
        handler(args.model)


# ---------------------------------------------------------------------------
# Support the shorthand syntax from the docstring:
#   python scripts/dispatch.py kaggle scratch
#   python scripts/dispatch.py kaggle status scratch
#
# argparse subcommands alone can't handle both forms, so we do a small
# pre-processing step on sys.argv before parsing.
# ---------------------------------------------------------------------------

def _preprocess_argv(argv: list[str] | None = None) -> None:
    """Rewrite argv in place so ``kaggle <model>`` becomes ``kaggle push <model>``.

    Args:
        argv: Argument vector including the program name at index 0. It is
            mutated in place. Defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    # Pattern: <script> kaggle <model> ... — argv[1] is the platform and
    # argv[2] is a registered kernel name where an action would normally be.
    if len(argv) >= 3 and argv[1] == "kaggle" and argv[2] in KERNEL_DIRS:
        argv.insert(2, "push")


if __name__ == "__main__":
    # Expand the `kaggle <model>` shorthand in sys.argv first, then hand off
    # to the argparse-based entry point.
    _preprocess_argv()
    main()