File size: 6,754 Bytes
e458112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# -*- coding: utf-8 -*-
"""Local client for controlling an active AniFileBERT Colab worker.

The worker still has to be started manually in Colab, but once it prints a
public URL and token this client lets Codex submit training jobs, tail logs, and
inspect status from the local workspace.
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
import sys
import time
from typing import Any
import urllib.error
import urllib.parse
import urllib.request


# Job "status" values after which wait_for_job stops polling and returns.
TERMINAL_STATES = {"success", "failed", "cancelled"}


def load_json(path: str) -> Any:
    """Read the UTF-8 encoded file at *path* and return its parsed JSON content."""
    with Path(path).open(encoding="utf-8") as handle:
        return json.loads(handle.read())


class ColabClient:
    """Small JSON-over-HTTP client for a running Colab worker.

    Every call sends ``Authorization: Bearer <token>`` and expects a JSON
    response body, which is returned decoded.
    """

    def __init__(self, base_url: str, token: str, timeout: int = 30):
        """Store connection settings.

        Args:
            base_url: Worker root URL; trailing slashes are removed so that
                request paths (which start with ``/``) can be appended as-is.
            token: Bearer token sent with every request.
            timeout: Per-request timeout in seconds passed to ``urlopen``.
        """
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.timeout = timeout

    @staticmethod
    def _job_path(job_id: str, suffix: str = "") -> str:
        """Return ``/jobs/<job_id><suffix>`` with the id percent-encoded.

        The id comes from the command line, so characters such as ``/``,
        ``?``, ``#`` or spaces must not be allowed to change the URL
        structure. ``safe=""`` makes ``/`` get encoded as well.
        """
        encoded_id = urllib.parse.quote(str(job_id), safe="")
        return f"/jobs/{encoded_id}{suffix}"

    def request(self, method: str, path: str, payload: Any | None = None) -> Any:
        """Send one HTTP request and return the decoded JSON response.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Path (and optional query string) appended to ``base_url``.
            payload: Optional JSON-serialisable body. When given it is sent
                as UTF-8 JSON with a matching ``Content-Type`` header.

        Returns:
            The response body parsed with ``json.loads``.

        Raises:
            RuntimeError: The server answered with an HTTP error status; the
                message contains the status code and the response body.

        Only ``urllib.error.HTTPError`` is translated; other failures raised
        by ``urlopen`` or by JSON decoding propagate unchanged.
        """
        url = self.base_url + path
        data = None
        headers = {"Authorization": f"Bearer {self.token}"}
        if payload is not None:
            data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            headers["Content-Type"] = "application/json; charset=utf-8"

        req = urllib.request.Request(url, data=data, headers=headers, method=method)
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as exc:
            # Include the error body: it usually carries the server's reason.
            body = exc.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"{method} {url} failed: HTTP {exc.code}: {body}") from exc

    def health(self) -> Any:
        """GET ``/health``."""
        return self.request("GET", "/health")

    def submit(self, payload: dict[str, Any]) -> Any:
        """POST *payload* to ``/jobs`` to create a job."""
        return self.request("POST", "/jobs", payload)

    def jobs(self) -> Any:
        """GET ``/jobs``."""
        return self.request("GET", "/jobs")

    def status(self, job_id: str) -> Any:
        """GET ``/jobs/<job_id>``."""
        return self.request("GET", self._job_path(job_id))

    def logs(self, job_id: str, tail: int) -> Any:
        """GET ``/jobs/<job_id>/logs`` with ``tail`` passed as a query parameter."""
        query = urllib.parse.urlencode({"tail": tail})
        return self.request("GET", self._job_path(job_id, f"/logs?{query}"))

    def manifest(self, job_id: str) -> Any:
        """GET ``/jobs/<job_id>/manifest``."""
        return self.request("GET", self._job_path(job_id, "/manifest"))

    def cancel(self, job_id: str) -> Any:
        """POST an empty JSON object to ``/jobs/<job_id>/cancel``."""
        return self.request("POST", self._job_path(job_id, "/cancel"), {})


def print_json(data: Any) -> None:
    """Write *data* to stdout as 2-space indented JSON, keeping non-ASCII text as-is."""
    rendered = json.dumps(data, indent=2, ensure_ascii=False)
    print(rendered)


def require_connection(args: argparse.Namespace) -> ColabClient:
    """Create a client from CLI options, falling back to environment variables.

    ``--url`` / ``--token`` win when non-empty; otherwise
    ``ANIFILEBERT_COLAB_URL`` / ``ANIFILEBERT_COLAB_TOKEN`` are consulted.

    Raises:
        SystemExit: Either the URL or the token could not be resolved.
    """
    environ = os.environ
    url = args.url if args.url else environ.get("ANIFILEBERT_COLAB_URL")
    token = args.token if args.token else environ.get("ANIFILEBERT_COLAB_TOKEN")

    if url and token:
        return ColabClient(url, token, timeout=args.timeout)

    message = (
        "Set ANIFILEBERT_COLAB_URL and ANIFILEBERT_COLAB_TOKEN, "
        "or pass --url and --token."
    )
    raise SystemExit(message)


def build_submit_payload(args: argparse.Namespace) -> dict[str, Any]:
    """Assemble the JSON body for a job submission from the ``submit`` options.

    Args:
        args: Parsed namespace providing ``config`` (path to a local JSON
            file), ``profile``, ``args`` (repeated ``--arg`` values) and
            ``extra_args`` (everything captured after the options).

    Returns:
        A dict with any of the keys ``"config"`` (parsed file content),
        ``"profile"`` and ``"args"`` (``--arg`` values followed by the
        trailing arguments). When none of them would be set, the dict
        contains only the default profile ``"dmhy_regex_finetune"``.
    """
    payload: dict[str, Any] = {}
    if args.config:
        payload["config"] = load_json(args.config)
    if args.profile:
        payload["profile"] = args.profile

    trailing = list(args.extra_args or [])
    # argparse.REMAINDER keeps the "--" separator as the first captured item.
    # It only delimits the local command line, so it must not be forwarded as
    # a training argument. Only one leading "--" is dropped, so a deliberate
    # second one still gets through.
    if trailing and trailing[0] == "--":
        del trailing[0]

    extra_args = list(args.args or []) + trailing
    if extra_args:
        payload["args"] = extra_args
    if not payload:
        payload["profile"] = "dmhy_regex_finetune"
    return payload


def wait_for_job(client: ColabClient, job_id: str, poll: int, tail: int) -> dict[str, Any]:
    """Poll a job until it reaches a terminal state and return its last status.

    On every iteration the status is fetched (and printed when its
    ``"status"`` field changed since the previous iteration), then the last
    *tail* log lines are fetched and printed when non-empty. The loop sleeps
    *poll* seconds between iterations.
    """
    previous_state = None
    while True:
        snapshot = client.status(job_id)
        state = snapshot.get("status")
        if state != previous_state:
            print_json(snapshot)
            previous_state = state

        tail_text = client.logs(job_id, tail=tail).get("log", "")
        if tail_text:
            print("\n--- log tail ---")
            print(tail_text.rstrip())

        if state in TERMINAL_STATES:
            return snapshot
        time.sleep(poll)


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    The top-level parser carries the connection options; the chosen
    subcommand name is stored in ``command`` and is mandatory.
    """
    root = argparse.ArgumentParser(
        description="Control an active AniFileBERT Colab worker",
    )
    root.add_argument("--url", help="Worker URL, or ANIFILEBERT_COLAB_URL")
    root.add_argument("--token", help="Worker token, or ANIFILEBERT_COLAB_TOKEN")
    root.add_argument("--timeout", type=int, default=30)

    commands = root.add_subparsers(dest="command", required=True)

    # Subcommands without any arguments of their own.
    commands.add_parser("health", help="Check worker health")
    commands.add_parser("jobs", help="List known jobs")

    submit_cmd = commands.add_parser("submit", help="Submit a training job")
    submit_cmd.add_argument(
        "--config",
        help="Local JSON config to send to the worker",
    )
    submit_cmd.add_argument(
        "--profile",
        help="Remote profile name under colab/configs",
    )
    submit_cmd.add_argument(
        "--arg",
        dest="args",
        action="append",
        default=[],
        help="Extra arg for colab_train.py",
    )
    submit_cmd.add_argument(
        "--wait",
        action="store_true",
        help="Poll until the job finishes",
    )
    submit_cmd.add_argument(
        "--poll",
        type=int,
        default=60,
        help="Polling interval in seconds",
    )
    submit_cmd.add_argument(
        "--tail",
        type=int,
        default=80,
        help="Log lines to show while waiting",
    )
    submit_cmd.add_argument(
        "extra_args",
        nargs=argparse.REMAINDER,
        help="Arguments after -- are passed to colab_train.py",
    )

    # Subcommands that address a single job by id, registered in this order.
    per_job_commands = (
        ("status", "Show job status"),
        ("logs", "Show job logs"),
        ("manifest", "Show job manifest"),
        ("cancel", "Cancel a running job"),
    )
    for name, summary in per_job_commands:
        job_cmd = commands.add_parser(name, help=summary)
        job_cmd.add_argument("job_id")
        if name == "logs":
            job_cmd.add_argument("--tail", type=int, default=200)

    return root.parse_args()


def main() -> None:
    """Parse the command line, connect to the worker and run one subcommand.

    ``submit --wait`` exits with status 1 when the job's final status is
    anything other than ``"success"``. ``logs`` prints the raw log text;
    every other subcommand pretty-prints the worker's JSON reply.
    """
    args = parse_args()
    client = require_connection(args)
    command = args.command

    if command == "submit":
        job = client.submit(build_submit_payload(args))
        print_json(job)
        if not args.wait:
            return
        outcome = wait_for_job(client, job["job_id"], poll=args.poll, tail=args.tail)
        if outcome.get("status") != "success":
            sys.exit(1)
        return

    if command == "logs":
        # Log text is emitted verbatim, without an extra trailing newline.
        log_reply = client.logs(args.job_id, args.tail)
        print(log_reply.get("log", ""), end="")
        return

    if command == "health":
        reply = client.health()
    elif command == "jobs":
        reply = client.jobs()
    elif command == "status":
        reply = client.status(args.job_id)
    elif command == "manifest":
        reply = client.manifest(args.job_id)
    elif command == "cancel":
        reply = client.cancel(args.job_id)
    else:
        return
    print_json(reply)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()