AniFileBERT / colab_client.py
ModerRAS's picture
Add Codex Colab training workflow
e458112
raw
history blame
6.75 kB
# -*- coding: utf-8 -*-
"""Local client for controlling an active AniFileBERT Colab worker.
The worker still has to be started manually in Colab, but once it prints a
public URL and token this client lets Codex submit training jobs, tail logs, and
inspect status from the local workspace.
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
import sys
import time
from typing import Any
import urllib.error
import urllib.parse
import urllib.request
# Job states that end polling: wait_for_job returns as soon as the worker
# reports one of these values in the job's "status" field.
TERMINAL_STATES = {"success", "failed", "cancelled"}
def load_json(path: str) -> Any:
    """Parse the UTF-8 encoded JSON file at *path* and return the decoded value."""
    with Path(path).open(encoding="utf-8") as handle:
        return json.load(handle)
class ColabClient:
    """Minimal HTTP client for the AniFileBERT Colab worker's JSON API.

    Every request carries the worker token as a bearer ``Authorization``
    header, and every response body is decoded as UTF-8 JSON.
    """

    def __init__(self, base_url: str, token: str, timeout: int = 30):
        """Store the connection settings.

        Args:
            base_url: Root URL of the worker. Trailing slashes are removed so
                request paths (which start with "/") can be appended directly.
            token: Bearer token sent with every request.
            timeout: Timeout passed to ``urllib.request.urlopen``.
        """
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.timeout = timeout

    @staticmethod
    def _job_path(job_id: str, suffix: str = "") -> str:
        """Return the request path for *job_id* followed by *suffix*.

        The ID is percent-encoded as a single path segment so characters such
        as "/", "?", "#" or spaces cannot change which endpoint is addressed.
        """
        segment = urllib.parse.quote(str(job_id), safe="")
        return f"/jobs/{segment}{suffix}"

    def request(self, method: str, path: str, payload: Any | None = None) -> Any:
        """Send one HTTP request to the worker and return the decoded JSON reply.

        Args:
            method: HTTP verb, e.g. "GET" or "POST".
            path: Path (plus optional query string) appended to the base URL.
            payload: Optional JSON-serialisable body. When it is not ``None``
                it is sent as UTF-8 encoded JSON with a matching Content-Type.

        Returns:
            The JSON-decoded response body.

        Raises:
            RuntimeError: If the worker answers with an HTTP error status. The
                message contains the method, URL, status code and body text.
        """
        url = self.base_url + path
        data = None
        headers = {"Authorization": f"Bearer {self.token}"}
        if payload is not None:
            data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            headers["Content-Type"] = "application/json; charset=utf-8"
        req = urllib.request.Request(url, data=data, headers=headers, method=method)
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as exc:
            # The error body may not be valid UTF-8; keep whatever is readable.
            body = exc.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"{method} {url} failed: HTTP {exc.code}: {body}") from exc

    def health(self) -> Any:
        """GET ``/health`` and return the decoded reply."""
        return self.request("GET", "/health")

    def submit(self, payload: dict[str, Any]) -> Any:
        """POST *payload* to ``/jobs`` and return the decoded reply."""
        return self.request("POST", "/jobs", payload)

    def jobs(self) -> Any:
        """GET ``/jobs`` and return the decoded reply."""
        return self.request("GET", "/jobs")

    def status(self, job_id: str) -> Any:
        """GET ``/jobs/<job_id>`` and return the decoded reply."""
        return self.request("GET", self._job_path(job_id))

    def logs(self, job_id: str, tail: int) -> Any:
        """GET ``/jobs/<job_id>/logs`` with *tail* sent as the ``tail`` query parameter."""
        query = urllib.parse.urlencode({"tail": tail})
        return self.request("GET", self._job_path(job_id, f"/logs?{query}"))

    def manifest(self, job_id: str) -> Any:
        """GET ``/jobs/<job_id>/manifest`` and return the decoded reply."""
        return self.request("GET", self._job_path(job_id, "/manifest"))

    def cancel(self, job_id: str) -> Any:
        """POST an empty JSON object to ``/jobs/<job_id>/cancel``."""
        return self.request("POST", self._job_path(job_id, "/cancel"), {})
def print_json(data: Any) -> None:
    """Write *data* to stdout as 2-space-indented JSON, leaving non-ASCII text unescaped."""
    rendered = json.dumps(data, indent=2, ensure_ascii=False)
    print(rendered)
def require_connection(args: argparse.Namespace) -> ColabClient:
    """Create a ColabClient from the parsed CLI options.

    ``args.url`` and ``args.token`` take precedence; when either is empty the
    ANIFILEBERT_COLAB_URL / ANIFILEBERT_COLAB_TOKEN environment variables are
    consulted instead.

    Raises:
        SystemExit: If the URL or the token is still missing after the
            environment fallback.
    """
    environment = os.environ
    url = args.url if args.url else environment.get("ANIFILEBERT_COLAB_URL")
    token = args.token if args.token else environment.get("ANIFILEBERT_COLAB_TOKEN")
    if url and token:
        return ColabClient(url, token, timeout=args.timeout)
    raise SystemExit(
        "Set ANIFILEBERT_COLAB_URL and ANIFILEBERT_COLAB_TOKEN, "
        "or pass --url and --token."
    )
def build_submit_payload(args: argparse.Namespace) -> dict[str, Any]:
    """Assemble the JSON body for a job submission from the ``submit`` options.

    Args:
        args: Parsed CLI namespace. Reads ``config`` (path to a local JSON
            file), ``profile``, ``args`` (values of repeated ``--arg``) and
            ``extra_args`` (everything captured after the options).

    Returns:
        A dict that may contain "config" (the loaded JSON file), "profile" and
        "args". When none of these were supplied, the payload falls back to
        the "dmhy_regex_finetune" profile.
    """
    payload: dict[str, Any] = {}
    if args.config:
        payload["config"] = load_json(args.config)
    if args.profile:
        payload["profile"] = args.profile

    trailing = list(args.extra_args or [])
    # argparse.REMAINDER keeps the literal "--" separator as the first captured
    # item. Only the arguments after it are meant to be forwarded, so drop that
    # one separator (a second "--" is passed through untouched).
    if trailing and trailing[0] == "--":
        trailing = trailing[1:]

    extra_args = list(args.args or []) + trailing
    if extra_args:
        payload["args"] = extra_args
    if not payload:
        payload["profile"] = "dmhy_regex_finetune"
    return payload
def wait_for_job(client: ColabClient, job_id: str, poll: int, tail: int) -> dict[str, Any]:
    """Poll a job until it reaches a terminal state, echoing progress to stdout.

    On every iteration the job status is fetched; it is printed as JSON only
    when its "status" value differs from the previous iteration. The last
    *tail* log lines are then fetched and printed if non-empty.

    Args:
        client: Client used to query the worker.
        job_id: Identifier of the job to watch.
        poll: Seconds to sleep between iterations.
        tail: Number of log lines requested on each iteration.

    Returns:
        The status document from the iteration in which the job's "status"
        was found in TERMINAL_STATES.
    """
    previous_state = None
    while True:
        snapshot = client.status(job_id)
        state = snapshot.get("status")
        if state != previous_state:
            print_json(snapshot)
            previous_state = state

        tail_text = client.logs(job_id, tail=tail).get("log", "")
        if tail_text:
            print("\n--- log tail ---")
            print(tail_text.rstrip())

        if state in TERMINAL_STATES:
            return snapshot
        time.sleep(poll)
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Global options are --url, --token and --timeout. A subcommand is required
    and is stored in ``command``: health, jobs, submit, status, logs,
    manifest or cancel.
    """
    root = argparse.ArgumentParser(description="Control an active AniFileBERT Colab worker")
    root.add_argument("--url", help="Worker URL, or ANIFILEBERT_COLAB_URL")
    root.add_argument("--token", help="Worker token, or ANIFILEBERT_COLAB_TOKEN")
    root.add_argument("--timeout", type=int, default=30)
    commands = root.add_subparsers(dest="command", required=True)

    # Subcommands that take no arguments of their own.
    for name, summary in (
        ("health", "Check worker health"),
        ("jobs", "List known jobs"),
    ):
        commands.add_parser(name, help=summary)

    submit_cmd = commands.add_parser("submit", help="Submit a training job")
    submit_cmd.add_argument("--config", help="Local JSON config to send to the worker")
    submit_cmd.add_argument("--profile", help="Remote profile name under colab/configs")
    submit_cmd.add_argument(
        "--arg",
        dest="args",
        action="append",
        default=[],
        help="Extra arg for colab_train.py",
    )
    submit_cmd.add_argument("--wait", action="store_true", help="Poll until the job finishes")
    submit_cmd.add_argument("--poll", type=int, default=60, help="Polling interval in seconds")
    submit_cmd.add_argument("--tail", type=int, default=80, help="Log lines to show while waiting")
    submit_cmd.add_argument(
        "extra_args",
        nargs=argparse.REMAINDER,
        help="Arguments after -- are passed to colab_train.py",
    )

    # Subcommands that operate on one job; registered in their original order.
    for name, summary in (
        ("status", "Show job status"),
        ("logs", "Show job logs"),
        ("manifest", "Show job manifest"),
        ("cancel", "Cancel a running job"),
    ):
        job_cmd = commands.add_parser(name, help=summary)
        job_cmd.add_argument("job_id")
        if name == "logs":
            job_cmd.add_argument("--tail", type=int, default=200)

    return root.parse_args()
def main() -> None:
    """Entry point: parse the CLI, connect to the worker and run the chosen command.

    ``submit`` prints the created job and, with --wait, polls until it ends,
    exiting with status 1 unless the final status is "success". ``logs``
    prints the raw log text without adding a trailing newline. Every other
    command prints the worker's reply as JSON.
    """
    args = parse_args()
    client = require_connection(args)
    command = args.command

    if command == "submit":
        job = client.submit(build_submit_payload(args))
        print_json(job)
        if not args.wait:
            return
        outcome = wait_for_job(client, job["job_id"], poll=args.poll, tail=args.tail)
        if outcome.get("status") != "success":
            sys.exit(1)
        return

    if command == "logs":
        print(client.logs(args.job_id, args.tail).get("log", ""), end="")
        return

    # The remaining commands all print one JSON reply. Lambdas defer the
    # request (and the args.job_id lookup) until the command is selected.
    json_commands = {
        "health": lambda: client.health(),
        "jobs": lambda: client.jobs(),
        "status": lambda: client.status(args.job_id),
        "manifest": lambda: client.manifest(args.job_id),
        "cancel": lambda: client.cancel(args.job_id),
    }
    handler = json_commands.get(command)
    if handler is not None:
        print_json(handler())


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()