AniFileBERT / colab_worker.py
ModerRAS's picture
Add Codex Colab training workflow
e458112
raw
history blame
16.9 kB
# -*- coding: utf-8 -*-
"""Small HTTP worker for running AniFileBERT training jobs on Google Colab.
Start this inside a Colab runtime:
python colab_worker.py
The worker exposes a token-protected local HTTP API and, by default, starts a
Cloudflare Quick Tunnel so Codex on your local machine can submit jobs.
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
import platform
import re
import secrets
import shutil
import signal
import subprocess
import sys
import threading
import time
import traceback
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse
import urllib.request
# Job statuses that mean a job has finished; a job in any other status is
# treated as still active.
TERMINAL_STATES = {"success", "failed", "cancelled"}
# Matches a public ``https://<name>.trycloudflare.com`` URL; used to pick the
# tunnel address out of the cloudflared process output.
TUNNEL_URL_RE = re.compile(r"https://[-a-zA-Z0-9.]+\.trycloudflare\.com")
def utc_timestamp() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``.

    The value has one-second resolution and always carries the literal ``Z``
    suffix.
    """
    now_utc = time.gmtime()
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", now_utc)
def json_dumps(data: Any) -> str:
    """Serialize *data* as human-readable JSON text.

    Output is indented by two spaces and non-ASCII characters are written
    verbatim instead of being ``\\u``-escaped.
    """
    formatting = {"ensure_ascii": False, "indent": 2}
    return json.dumps(data, **formatting)
def read_tail(path: Path, lines: int) -> str:
    """Return the last *lines* lines of the file at *path* as text.

    Args:
        path: File to read. A path that is not an existing regular file
            yields an empty string.
        lines: Number of trailing lines wanted. Zero or a negative value
            returns the whole file.

    Returns:
        The requested lines joined with ``"\\n"`` (no trailing newline),
        decoded as UTF-8 with undecodable bytes replaced. When the whole file
        is returned it is returned unmodified.
    """
    if not path.is_file():
        return ""
    if lines <= 0:
        return path.read_text(encoding="utf-8", errors="replace")

    chunk_size = 8192
    # Chunks are collected newest-first and joined once at the end, and the
    # newline total is updated per chunk, so the cost stays linear in the
    # number of bytes read rather than re-scanning the buffer every pass.
    chunks: list[bytes] = []
    newline_count = 0
    with path.open("rb") as f:
        f.seek(0, os.SEEK_END)
        pos = f.tell()
        # Read one newline more than requested so the (possibly partial)
        # first line in the buffer is never part of the result.
        while pos > 0 and newline_count <= lines:
            read_size = min(chunk_size, pos)
            pos -= read_size
            f.seek(pos)
            chunk = f.read(read_size)
            newline_count += chunk.count(b"\n")
            chunks.append(chunk)

    data = b"".join(reversed(chunks))
    return b"\n".join(data.splitlines()[-lines:]).decode("utf-8", errors="replace")
def download_cloudflared(path: Path) -> Path:
    """Locate a ``cloudflared`` binary, downloading it to *path* if needed.

    Lookup order: an existing file at *path*, then any ``cloudflared`` found
    on ``PATH``, and finally a download of the latest Linux release build for
    the current CPU architecture.

    Args:
        path: Where the downloaded binary is (or should be) stored.

    Returns:
        Path of the binary to execute.

    Raises:
        RuntimeError: If the CPU architecture has no known release build.
    """
    if path.is_file():
        return path
    existing = shutil.which("cloudflared")
    if existing:
        return Path(existing)

    arch = platform.machine().lower()
    if arch in {"x86_64", "amd64"}:
        suffix = "linux-amd64"
    elif arch in {"aarch64", "arm64"}:
        suffix = "linux-arm64"
    else:
        raise RuntimeError(f"Unsupported CPU architecture for cloudflared: {arch}")

    url = f"https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-{suffix}"
    print(f"Downloading cloudflared: {url}", flush=True)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Download next to the target and rename into place only once complete.
    # Writing directly to *path* would leave a truncated file behind after an
    # interrupted download, and the is_file() shortcut above would then hand
    # that broken binary back on every later call.
    partial = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        partial.chmod(0o755)
        partial.replace(path)
    except BaseException:
        if partial.exists():
            partial.unlink()
        raise
    return path
class WorkerState:
    """In-memory registry of training jobs plus the logic that runs them.

    Each job is a dict kept in ``self.jobs`` under its ``job_id``. The record
    holds status/timestamp fields, the command line, the job/log/config paths
    and, while the child is alive, the ``subprocess.Popen`` object under the
    ``"process"`` key. All access to ``self.jobs`` and to job records is
    guarded by ``self.lock``.
    """

    def __init__(self, repo_dir: Path, jobs_dir: Path):
        """Remember the repository checkout and create *jobs_dir* if missing."""
        self.repo_dir = repo_dir
        self.jobs_dir = jobs_dir
        self.jobs_dir.mkdir(parents=True, exist_ok=True)
        self.jobs: dict[str, dict[str, Any]] = {}
        # Re-entrant because start_job() calls active_job() while holding it.
        self.lock = threading.RLock()

    def list_jobs(self) -> list[dict[str, Any]]:
        """Return the public view of every known job."""
        with self.lock:
            return [self._public_job(job) for job in self.jobs.values()]

    def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return the public view of one job, or ``None`` if it is unknown."""
        with self.lock:
            job = self.jobs.get(job_id)
            return self._public_job(job) if job else None

    def get_job_internal(self, job_id: str) -> dict[str, Any] | None:
        """Return the live (mutable) job record, or ``None`` if it is unknown."""
        with self.lock:
            return self.jobs.get(job_id)

    def active_job(self) -> dict[str, Any] | None:
        """Return the first job whose status is not terminal, else ``None``.

        The returned dict is the live record, including the ``"process"`` key.
        """
        with self.lock:
            for job in self.jobs.values():
                if job["status"] not in TERMINAL_STATES:
                    return job
            return None

    def start_job(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Register a new job and start it on a background thread.

        Args:
            payload: Request body. ``"config"`` supplies an inline config,
                otherwise ``"profile"`` names a JSON file under
                ``<repo_dir>/colab/configs``. ``"args"`` is an optional list
                of extra command-line arguments appended to the command.

        Returns:
            The public view of the newly queued job.

        Raises:
            RuntimeError: If another job is still active.
            FileNotFoundError: If the requested profile file does not exist.
        """
        with self.lock:
            active = self.active_job()
            if active is not None:
                raise RuntimeError(f"Job already running: {active['job_id']}")

            job_id = time.strftime("%Y%m%d-%H%M%S", time.gmtime()) + "-" + secrets.token_hex(3)
            job_dir = self.jobs_dir / job_id
            job_dir.mkdir(parents=True, exist_ok=True)
            log_path = job_dir / "worker.log"

            cmd = [sys.executable, "colab_train.py"]
            config = self._job_config(payload)
            # Point the run manifest at the job directory so it can be found
            # later from the job record alone.
            config.setdefault("artifacts", {})
            config["artifacts"]["manifest"] = os.fspath(job_dir / "colab_run_manifest.json")
            config_path = job_dir / "config.json"
            config_path.write_text(json_dumps(config), encoding="utf-8")
            cmd.extend(["--config", os.fspath(config_path)])
            # "or []" tolerates an explicit JSON null for "args".
            cmd.extend(str(arg) for arg in (payload.get("args") or []))

            job = {
                "job_id": job_id,
                "status": "queued",
                "created_at": utc_timestamp(),
                "started_at": None,
                "finished_at": None,
                "returncode": None,
                "cmd": cmd,
                "cwd": os.fspath(self.repo_dir),
                "job_dir": os.fspath(job_dir),
                "log_path": os.fspath(log_path),
                "config_path": os.fspath(config_path),
                "error": None,
                "process": None,
            }
            self.jobs[job_id] = job
            thread = threading.Thread(target=self._run_job, args=(job_id,), daemon=True)
            thread.start()
            return self._public_job(job)

    def _job_config(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Return the config for a job as an independent, mutable object.

        An inline ``payload["config"]`` is deep-copied through a JSON round
        trip; otherwise the profile named by ``payload["profile"]`` (default
        ``"dmhy_regex_finetune"``) is loaded from the repository.

        Raises:
            FileNotFoundError: If the profile file does not exist.
        """
        if "config" in payload:
            return json.loads(json.dumps(payload["config"], ensure_ascii=False))
        profile = str(payload.get("profile", "dmhy_regex_finetune"))
        profile_path = self.repo_dir / "colab" / "configs" / f"{profile}.json"
        if not profile_path.is_file():
            raise FileNotFoundError(f"Profile not found: {profile_path}")
        return json.loads(profile_path.read_text(encoding="utf-8"))

    def cancel_job(self, job_id: str) -> dict[str, Any]:
        """Mark a job as cancelled and signal its process, if any.

        A job that is already in a terminal state is returned unchanged.

        Raises:
            KeyError: If *job_id* is unknown.
        """
        with self.lock:
            job = self.jobs.get(job_id)
            if job is None:
                raise KeyError(job_id)
            if job["status"] in TERMINAL_STATES:
                return self._public_job(job)
            process: subprocess.Popen[str] | None = job.get("process")
            job["status"] = "cancelled"
            job["finished_at"] = utc_timestamp()
        if process is not None:
            self._terminate(process)
        return self.get_job(job_id) or {}

    @staticmethod
    def _terminate(process: subprocess.Popen[str]) -> None:
        """Send SIGTERM to the child's process group, or to the child alone."""
        if process.poll() is not None:
            return
        try:
            os.killpg(os.getpgid(process.pid), signal.SIGTERM)
        except Exception:
            # No process-group support (or the group lookup failed): fall
            # back to terminating just the direct child.
            process.terminate()

    def _run_job(self, job_id: str) -> None:
        """Run one job to completion; executed on the job's worker thread.

        Child output is copied line by line to the job log and to this
        process's stdout. The job record is updated with the final status,
        return code and timestamps.
        """
        job = self.get_job_internal(job_id)
        if job is None:
            return
        log_path = Path(job["log_path"])
        process: subprocess.Popen[str] | None = None
        try:
            with self.lock:
                if job["status"] in TERMINAL_STATES:
                    # Cancelled while still queued: do not resurrect the job
                    # by flipping it back to "running".
                    return
                job["status"] = "running"
                job["started_at"] = utc_timestamp()

            with log_path.open("w", encoding="utf-8", errors="replace") as log:
                log.write(f"job_id={job_id}\n")
                log.write(f"cwd={job['cwd']}\n")
                log.write("$ " + " ".join(job["cmd"]) + "\n\n")
                log.flush()
                process = subprocess.Popen(
                    job["cmd"],
                    cwd=job["cwd"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    encoding="utf-8",
                    errors="replace",
                    bufsize=1,
                    # Own session/process group so cancel can signal the
                    # whole tree. start_new_session is used instead of
                    # preexec_fn because preexec_fn is unsafe when the
                    # parent has other threads, which this worker does.
                    start_new_session=hasattr(os, "setsid"),
                )
                with self.lock:
                    job["process"] = process
                    cancelled_during_spawn = job["status"] == "cancelled"
                if cancelled_during_spawn:
                    # cancel_job() ran before the process was recorded and so
                    # could not signal it; do that now.
                    self._terminate(process)

                assert process.stdout is not None
                for line in process.stdout:
                    log.write(line)
                    log.flush()
                    print(line, end="", flush=True)
                process.wait()

            with self.lock:
                job["returncode"] = process.returncode
                if job["status"] != "cancelled":
                    job["status"] = "success" if process.returncode == 0 else "failed"
                job["finished_at"] = utc_timestamp()
                job["process"] = None
        except Exception as exc:
            details = traceback.format_exc()
            if process is not None:
                # Do not leave an untracked child running after a failure.
                self._terminate(process)
            # Update the record before touching the log again, so a broken
            # log file cannot leave the job stuck in a non-terminal state.
            with self.lock:
                job["status"] = "failed"
                job["finished_at"] = utc_timestamp()
                job["error"] = f"{type(exc).__name__}: {exc}"
                job["process"] = None
            try:
                with log_path.open("a", encoding="utf-8", errors="replace") as log:
                    log.write(details)
            except OSError:
                print(details, file=sys.stderr, flush=True)

    def _public_job(self, job: dict[str, Any]) -> dict[str, Any]:
        """Return a shallow copy of *job* without the ``"process"`` entry."""
        return {key: value for key, value in job.items() if key != "process"}
def make_handler(state: WorkerState, token: str):
    """Build the HTTP request-handler class bound to *state* and *token*.

    Routes (all require the token, via ``Authorization: Bearer <token>`` or
    ``X-Colab-Token: <token>``):

    * ``GET  /health``               - worker paths and the active job id
    * ``GET  /jobs``                 - list all jobs
    * ``POST /jobs``                 - start a job from the JSON body
    * ``GET  /jobs/<id>``            - one job
    * ``GET  /jobs/<id>/logs?tail=N``- last N log lines (default 200)
    * ``GET  /jobs/<id>/manifest``   - the job's run manifest
    * ``POST /jobs/<id>/cancel``     - cancel a job

    Returns:
        A ``BaseHTTPRequestHandler`` subclass suitable for an HTTP server.
    """
    # Tokens are compared as bytes because compare_digest() rejects non-ASCII
    # str values; "surrogatepass" makes the encoding total over any str.
    expected_token = token.encode("utf-8", "surrogatepass")
    expected_bearer = f"Bearer {token}".encode("utf-8", "surrogatepass")

    class Handler(BaseHTTPRequestHandler):
        """JSON API handler for the worker."""

        server_version = "AniFileBERTColabWorker/1.0"

        def log_message(self, fmt: str, *args: Any) -> None:
            """Print access-log lines to stdout with a UTC timestamp."""
            print(f"[{utc_timestamp()}] {self.address_string()} {fmt % args}", flush=True)

        def do_GET(self) -> None:
            self._handle("GET")

        def do_POST(self) -> None:
            self._handle("POST")

        def _handle(self, method: str) -> None:
            """Authenticate the request and dispatch it to the matching route."""
            parsed = urlparse(self.path)
            path = parsed.path.rstrip("/") or "/"
            parts = [part for part in path.split("/") if part]
            try:
                if not self._authorized():
                    self._send({"error": "unauthorized"}, HTTPStatus.UNAUTHORIZED)
                    return

                if method == "GET" and path == "/health":
                    # Look the active job up once: asking twice lets the job
                    # finish in between and turns the second answer into None.
                    active = state.active_job()
                    self._send(
                        {
                            "ok": True,
                            "repo_dir": os.fspath(state.repo_dir),
                            "jobs_dir": os.fspath(state.jobs_dir),
                            "active_job": active["job_id"] if active else None,
                        }
                    )
                    return

                if method == "GET" and path == "/jobs":
                    self._send({"jobs": state.list_jobs()})
                    return

                if method == "POST" and path == "/jobs":
                    self._start_job()
                    return

                if len(parts) >= 2 and parts[0] == "jobs":
                    job_id = parts[1]
                    if method == "GET" and len(parts) == 2:
                        job = state.get_job(job_id)
                        if job is None:
                            self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
                        else:
                            self._send(job)
                        return

                    if method == "GET" and len(parts) == 3 and parts[2] == "logs":
                        query = parse_qs(parsed.query)
                        try:
                            tail = int(query.get("tail", ["200"])[0])
                        except ValueError:
                            self._send({"error": "tail must be an integer"}, HTTPStatus.BAD_REQUEST)
                            return
                        job = state.get_job_internal(job_id)
                        if job is None:
                            self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
                        else:
                            self._send({"job_id": job_id, "log": read_tail(Path(job["log_path"]), tail)})
                        return

                    if method == "GET" and len(parts) == 3 and parts[2] == "manifest":
                        job = state.get_job_internal(job_id)
                        if job is None:
                            self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
                            return
                        manifest = self._find_manifest(job)
                        if manifest is None:
                            self._send({"error": "manifest not found"}, HTTPStatus.NOT_FOUND)
                        else:
                            self._send(json.loads(manifest.read_text(encoding="utf-8")))
                        return

                    if method == "POST" and len(parts) == 3 and parts[2] == "cancel":
                        try:
                            self._send(state.cancel_job(job_id))
                        except KeyError:
                            self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
                        return

                self._send({"error": "not found"}, HTTPStatus.NOT_FOUND)
            except Exception as exc:
                # Last-resort boundary: report anything unexpected as a 500.
                traceback.print_exc()
                self._send({"error": f"{type(exc).__name__}: {exc}"}, HTTPStatus.INTERNAL_SERVER_ERROR)

        def _start_job(self) -> None:
            """Handle ``POST /jobs``: validate the body and start the job."""
            try:
                payload = self._read_json()
            except ValueError as exc:
                # Covers a bad Content-Length, invalid UTF-8 and invalid JSON
                # (the decode errors are ValueError subclasses).
                self._send({"error": f"invalid JSON body: {exc}"}, HTTPStatus.BAD_REQUEST)
                return
            if not isinstance(payload, dict):
                self._send({"error": "JSON body must be an object"}, HTTPStatus.BAD_REQUEST)
                return
            active = state.active_job()
            if active is not None:
                # A busy worker is a client-visible conflict, not a server
                # fault. If a job sneaks in after this check, start_job()
                # still refuses and the generic handler reports it.
                self._send(
                    {"error": f"RuntimeError: Job already running: {active['job_id']}"},
                    HTTPStatus.CONFLICT,
                )
                return
            self._send(state.start_job(payload), HTTPStatus.ACCEPTED)

        def _authorized(self) -> bool:
            """Return True if the request carries the expected token.

            Comparisons are constant-time so response timing does not leak
            how much of a guessed token was correct.
            """
            bearer = self.headers.get("Authorization", "")
            if secrets.compare_digest(bearer.encode("utf-8", "surrogatepass"), expected_bearer):
                return True
            supplied = self.headers.get("X-Colab-Token")
            if supplied is None:
                return False
            return secrets.compare_digest(supplied.encode("utf-8", "surrogatepass"), expected_token)

        def _read_json(self) -> dict[str, Any]:
            """Read and parse the request body; an empty body yields ``{}``.

            Raises:
                ValueError: If Content-Length or the body cannot be parsed.
            """
            length = int(self.headers.get("Content-Length", "0"))
            # A negative length would make read() block until end of stream.
            if length <= 0:
                return {}
            raw = self.rfile.read(length)
            return json.loads(raw.decode("utf-8"))

        def _find_manifest(self, job: dict[str, Any]) -> Path | None:
            """Locate the run manifest for *job*, or return ``None``.

            The ``training.save_dir`` named in the job's config is checked
            first, then the job directory itself.
            """
            config_path = job.get("config_path")
            if config_path and Path(config_path).is_file():
                config = json.loads(Path(config_path).read_text(encoding="utf-8"))
                training = config.get("training", {})
                save_dir = training.get("save_dir")
                if save_dir:
                    manifest = Path(save_dir) / "colab_run_manifest.json"
                    if manifest.is_file():
                        return manifest
            job_manifest = Path(job["job_dir"]) / "colab_run_manifest.json"
            return job_manifest if job_manifest.is_file() else None

        def _send(self, data: Any, status: HTTPStatus = HTTPStatus.OK) -> None:
            """Send *data* as a JSON response with the given status."""
            raw = json_dumps(data).encode("utf-8")
            self.send_response(status.value)
            self.send_header("Content-Type", "application/json; charset=utf-8")
            self.send_header("Content-Length", str(len(raw)))
            self.end_headers()
            self.wfile.write(raw)

    return Handler
def start_tunnel(port: int, binary_path: Path) -> subprocess.Popen[str]:
    """Launch a Cloudflare Quick Tunnel that forwards to the local *port*.

    The cloudflared binary is resolved via ``download_cloudflared``. A daemon
    thread echoes the tunnel's output and, whenever a line contains a
    ``trycloudflare.com`` URL, prints it again as ``COLAB_WORKER_URL=<url>``.

    Returns:
        The running cloudflared process.
    """
    executable = os.fspath(download_cloudflared(binary_path))
    local_target = f"http://127.0.0.1:{port}"
    tunnel = subprocess.Popen(
        [executable, "tunnel", "--url", local_target, "--no-autoupdate"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    )

    def relay_output() -> None:
        """Echo cloudflared output and announce the public URL when seen."""
        stream = tunnel.stdout
        assert stream is not None
        for output_line in stream:
            print(output_line, end="", flush=True)
            found = TUNNEL_URL_RE.search(output_line)
            if found is not None:
                print("\nCOLAB_WORKER_URL=" + found.group(0), flush=True)

    relay_thread = threading.Thread(target=relay_output, daemon=True)
    relay_thread.start()
    return tunnel
def parse_args() -> argparse.Namespace:
    """Parse the worker's command-line options from ``sys.argv``.

    The ``--token`` default is taken from the ``ANIFILEBERT_COLAB_TOKEN``
    environment variable at call time.
    """
    cli = argparse.ArgumentParser(description="Start the AniFileBERT Colab worker")
    option = cli.add_argument

    # Network binding.
    option("--host", default="127.0.0.1", help="HTTP bind host")
    option("--port", type=int, default=7860, help="HTTP bind port")

    # Filesystem locations.
    option("--repo-dir", default="/content/AniFileBERT", help="AniFileBERT checkout path in Colab")
    option("--jobs-dir", default="/content/drive/MyDrive/AniFileBERT/worker/jobs")

    # Authentication and tunnelling.
    env_token = os.environ.get("ANIFILEBERT_COLAB_TOKEN")
    option("--token", default=env_token)
    option("--tunnel", choices=["cloudflare", "none"], default="cloudflare")
    option("--cloudflared-path", default="/tmp/anifilebert-cloudflared")

    return cli.parse_args()
def main() -> None:
    """Start the HTTP worker (and optional tunnel) and serve until stopped.

    Uses the token from the command line or environment if given, otherwise
    generates a random one, and prints it together with the local URL.

    Raises:
        RuntimeError: If the repository directory does not exist.
    """
    args = parse_args()
    token = args.token or secrets.token_urlsafe(24)
    repo_dir = Path(args.repo_dir)
    if not repo_dir.is_dir():
        raise RuntimeError(f"Repo directory does not exist: {repo_dir}")

    state = WorkerState(repo_dir=repo_dir, jobs_dir=Path(args.jobs_dir))
    server = ThreadingHTTPServer((args.host, args.port), make_handler(state, token))
    tunnel_proc: subprocess.Popen[str] | None = None

    # Everything after the socket is bound runs inside the try block, so a
    # failure while starting the tunnel still closes the listening socket.
    try:
        print("=" * 72)
        print("AniFileBERT Colab worker is starting")
        print(f"Local URL: http://{args.host}:{args.port}")
        print(f"COLAB_WORKER_TOKEN={token}")
        print("Keep this Colab cell running while Codex uses the worker.")
        print("=" * 72, flush=True)

        if args.tunnel == "cloudflare":
            tunnel_proc = start_tunnel(args.port, Path(args.cloudflared_path))
        else:
            print("Tunnel disabled. Use the local URL from inside the Colab runtime.", flush=True)

        server.serve_forever()
    finally:
        server.server_close()
        if tunnel_proc and tunnel_proc.poll() is None:
            tunnel_proc.terminate()
            # Reap the child so it does not linger as a zombie; escalate to
            # kill if it ignores the termination request.
            try:
                tunnel_proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                tunnel_proc.kill()
# Start the worker only when this file is executed as a script; importing the
# module must not launch a server.
if __name__ == "__main__":
    main()