# -*- coding: utf-8 -*- """Small HTTP worker for running AniFileBERT training jobs on Google Colab. Start this inside a Colab runtime: python colab_worker.py The worker exposes a token-protected local HTTP API and, by default, starts a Cloudflare Quick Tunnel so Codex on your local machine can submit jobs. """ from __future__ import annotations import argparse import json import os from pathlib import Path import platform import re import secrets import shutil import signal import subprocess import sys import threading import time import traceback from http import HTTPStatus from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from typing import Any from urllib.parse import parse_qs, urlparse import urllib.request TERMINAL_STATES = {"success", "failed", "cancelled"} TUNNEL_URL_RE = re.compile(r"https://[-a-zA-Z0-9.]+\.trycloudflare\.com") def utc_timestamp() -> str: return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) def json_dumps(data: Any) -> str: return json.dumps(data, ensure_ascii=False, indent=2) def read_tail(path: Path, lines: int) -> str: if not path.is_file(): return "" if lines <= 0: return path.read_text(encoding="utf-8", errors="replace") chunk_size = 8192 data = b"" with path.open("rb") as f: f.seek(0, os.SEEK_END) pos = f.tell() while pos > 0 and data.count(b"\n") <= lines: read_size = min(chunk_size, pos) pos -= read_size f.seek(pos) data = f.read(read_size) + data return b"\n".join(data.splitlines()[-lines:]).decode("utf-8", errors="replace") def download_cloudflared(path: Path) -> Path: if path.is_file(): return path existing = shutil.which("cloudflared") if existing: return Path(existing) arch = platform.machine().lower() if arch in {"x86_64", "amd64"}: suffix = "linux-amd64" elif arch in {"aarch64", "arm64"}: suffix = "linux-arm64" else: raise RuntimeError(f"Unsupported CPU architecture for cloudflared: {arch}") url = f"https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-{suffix}" print(f"Downloading cloudflared: {url}", flush=True) path.parent.mkdir(parents=True, exist_ok=True) urllib.request.urlretrieve(url, path) path.chmod(0o755) return path class WorkerState: def __init__(self, repo_dir: Path, jobs_dir: Path): self.repo_dir = repo_dir self.jobs_dir = jobs_dir self.jobs_dir.mkdir(parents=True, exist_ok=True) self.jobs: dict[str, dict[str, Any]] = {} self.lock = threading.RLock() def list_jobs(self) -> list[dict[str, Any]]: with self.lock: return [self._public_job(job) for job in self.jobs.values()] def get_job(self, job_id: str) -> dict[str, Any] | None: with self.lock: job = self.jobs.get(job_id) return self._public_job(job) if job else None def get_job_internal(self, job_id: str) -> dict[str, Any] | None: with self.lock: return self.jobs.get(job_id) def active_job(self) -> dict[str, Any] | None: with self.lock: for job in self.jobs.values(): if job["status"] not in TERMINAL_STATES: return job return None def start_job(self, payload: dict[str, Any]) -> dict[str, Any]: with self.lock: active = self.active_job() if active is not None: raise RuntimeError(f"Job already running: {active['job_id']}") job_id = time.strftime("%Y%m%d-%H%M%S", time.gmtime()) + "-" + secrets.token_hex(3) job_dir = self.jobs_dir / job_id job_dir.mkdir(parents=True, exist_ok=True) log_path = job_dir / "worker.log" config_path: Path | None = None cmd = [sys.executable, "colab_train.py"] config = self._job_config(payload) config.setdefault("artifacts", {}) config["artifacts"]["manifest"] = os.fspath(job_dir / "colab_run_manifest.json") config_path = job_dir / "config.json" config_path.write_text(json_dumps(config), encoding="utf-8") cmd.extend(["--config", os.fspath(config_path)]) for arg in payload.get("args", []): cmd.append(str(arg)) job = { "job_id": job_id, "status": "queued", "created_at": utc_timestamp(), "started_at": None, "finished_at": None, "returncode": None, "cmd": cmd, "cwd": os.fspath(self.repo_dir), "job_dir": os.fspath(job_dir), "log_path": os.fspath(log_path), "config_path": os.fspath(config_path) if config_path else None, "error": None, "process": None, } self.jobs[job_id] = job thread = threading.Thread(target=self._run_job, args=(job_id,), daemon=True) thread.start() return self._public_job(job) def _job_config(self, payload: dict[str, Any]) -> dict[str, Any]: if "config" in payload: return json.loads(json.dumps(payload["config"], ensure_ascii=False)) profile = str(payload.get("profile", "dmhy_regex_finetune")) profile_path = self.repo_dir / "colab" / "configs" / f"{profile}.json" if not profile_path.is_file(): raise FileNotFoundError(f"Profile not found: {profile_path}") return json.loads(profile_path.read_text(encoding="utf-8")) def cancel_job(self, job_id: str) -> dict[str, Any]: with self.lock: job = self.jobs.get(job_id) if job is None: raise KeyError(job_id) process: subprocess.Popen[str] | None = job.get("process") if job["status"] in TERMINAL_STATES: return self._public_job(job) job["status"] = "cancelled" job["finished_at"] = utc_timestamp() if process and process.poll() is None: try: os.killpg(os.getpgid(process.pid), signal.SIGTERM) except Exception: process.terminate() return self.get_job(job_id) or {} def _run_job(self, job_id: str) -> None: job = self.get_job_internal(job_id) if job is None: return log_path = Path(job["log_path"]) try: with self.lock: job["status"] = "running" job["started_at"] = utc_timestamp() with log_path.open("w", encoding="utf-8", errors="replace") as log: log.write(f"job_id={job_id}\n") log.write(f"cwd={job['cwd']}\n") log.write("$ " + " ".join(job["cmd"]) + "\n\n") log.flush() process = subprocess.Popen( job["cmd"], cwd=job["cwd"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding="utf-8", errors="replace", bufsize=1, preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) with self.lock: job["process"] = process assert process.stdout is not None for line in process.stdout: log.write(line) log.flush() print(line, end="", flush=True) process.wait() with self.lock: job["returncode"] = process.returncode if job["status"] != "cancelled": job["status"] = "success" if process.returncode == 0 else "failed" job["finished_at"] = utc_timestamp() job["process"] = None except Exception as exc: with log_path.open("a", encoding="utf-8", errors="replace") as log: traceback.print_exc(file=log) with self.lock: job["status"] = "failed" job["finished_at"] = utc_timestamp() job["error"] = f"{type(exc).__name__}: {exc}" job["process"] = None def _public_job(self, job: dict[str, Any]) -> dict[str, Any]: public = {key: value for key, value in job.items() if key != "process"} return public def make_handler(state: WorkerState, token: str): class Handler(BaseHTTPRequestHandler): server_version = "AniFileBERTColabWorker/1.0" def log_message(self, fmt: str, *args: Any) -> None: print(f"[{utc_timestamp()}] {self.address_string()} {fmt % args}", flush=True) def do_GET(self) -> None: self._handle("GET") def do_POST(self) -> None: self._handle("POST") def _handle(self, method: str) -> None: parsed = urlparse(self.path) path = parsed.path.rstrip("/") or "/" parts = [part for part in path.split("/") if part] try: if not self._authorized(): self._send({"error": "unauthorized"}, HTTPStatus.UNAUTHORIZED) return if method == "GET" and path == "/health": self._send( { "ok": True, "repo_dir": os.fspath(state.repo_dir), "jobs_dir": os.fspath(state.jobs_dir), "active_job": state.active_job()["job_id"] if state.active_job() else None, } ) return if method == "GET" and path == "/jobs": self._send({"jobs": state.list_jobs()}) return if method == "POST" and path == "/jobs": payload = self._read_json() job = state.start_job(payload) self._send(job, HTTPStatus.ACCEPTED) return if len(parts) >= 2 and parts[0] == "jobs": job_id = parts[1] if method == "GET" and len(parts) == 2: job = state.get_job(job_id) if job is None: self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND) else: self._send(job) return if method == "GET" and len(parts) == 3 and parts[2] == "logs": query = parse_qs(parsed.query) tail = int(query.get("tail", ["200"])[0]) job = state.get_job_internal(job_id) if job is None: self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND) else: self._send({"job_id": job_id, "log": read_tail(Path(job["log_path"]), tail)}) return if method == "GET" and len(parts) == 3 and parts[2] == "manifest": job = state.get_job_internal(job_id) if job is None: self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND) else: manifest = self._find_manifest(job) if manifest is None: self._send({"error": "manifest not found"}, HTTPStatus.NOT_FOUND) else: self._send(json.loads(manifest.read_text(encoding="utf-8"))) return if method == "POST" and len(parts) == 3 and parts[2] == "cancel": try: self._send(state.cancel_job(job_id)) except KeyError: self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND) return self._send({"error": "not found"}, HTTPStatus.NOT_FOUND) except Exception as exc: traceback.print_exc() self._send({"error": f"{type(exc).__name__}: {exc}"}, HTTPStatus.INTERNAL_SERVER_ERROR) def _authorized(self) -> bool: header = self.headers.get("Authorization", "") if header == f"Bearer {token}": return True return self.headers.get("X-Colab-Token") == token def _read_json(self) -> dict[str, Any]: length = int(self.headers.get("Content-Length", "0")) if length == 0: return {} raw = self.rfile.read(length) return json.loads(raw.decode("utf-8")) def _find_manifest(self, job: dict[str, Any]) -> Path | None: config_path = job.get("config_path") if config_path and Path(config_path).is_file(): config = json.loads(Path(config_path).read_text(encoding="utf-8")) training = config.get("training", {}) save_dir = training.get("save_dir") if save_dir: manifest = Path(save_dir) / "colab_run_manifest.json" if manifest.is_file(): return manifest job_manifest = Path(job["job_dir"]) / "colab_run_manifest.json" return job_manifest if job_manifest.is_file() else None def _send(self, data: Any, status: HTTPStatus = HTTPStatus.OK) -> None: raw = json_dumps(data).encode("utf-8") self.send_response(status.value) self.send_header("Content-Type", "application/json; charset=utf-8") self.send_header("Content-Length", str(len(raw))) self.end_headers() self.wfile.write(raw) return Handler def start_tunnel(port: int, binary_path: Path) -> subprocess.Popen[str]: cloudflared = download_cloudflared(binary_path) cmd = [ os.fspath(cloudflared), "tunnel", "--url", f"http://127.0.0.1:{port}", "--no-autoupdate", ] proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding="utf-8", errors="replace", bufsize=1, ) def pump() -> None: assert proc.stdout is not None for line in proc.stdout: print(line, end="", flush=True) match = TUNNEL_URL_RE.search(line) if match: print("\nCOLAB_WORKER_URL=" + match.group(0), flush=True) threading.Thread(target=pump, daemon=True).start() return proc def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Start the AniFileBERT Colab worker") parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host") parser.add_argument("--port", type=int, default=7860, help="HTTP bind port") parser.add_argument("--repo-dir", default="/content/AniFileBERT", help="AniFileBERT checkout path in Colab") parser.add_argument("--jobs-dir", default="/content/drive/MyDrive/AniFileBERT/worker/jobs") parser.add_argument("--token", default=os.environ.get("ANIFILEBERT_COLAB_TOKEN")) parser.add_argument("--tunnel", choices=["cloudflare", "none"], default="cloudflare") parser.add_argument("--cloudflared-path", default="/tmp/anifilebert-cloudflared") return parser.parse_args() def main() -> None: args = parse_args() token = args.token or secrets.token_urlsafe(24) repo_dir = Path(args.repo_dir) if not repo_dir.is_dir(): raise RuntimeError(f"Repo directory does not exist: {repo_dir}") state = WorkerState(repo_dir=repo_dir, jobs_dir=Path(args.jobs_dir)) server = ThreadingHTTPServer((args.host, args.port), make_handler(state, token)) tunnel_proc: subprocess.Popen[str] | None = None print("=" * 72) print("AniFileBERT Colab worker is starting") print(f"Local URL: http://{args.host}:{args.port}") print(f"COLAB_WORKER_TOKEN={token}") print("Keep this Colab cell running while Codex uses the worker.") print("=" * 72, flush=True) if args.tunnel == "cloudflare": tunnel_proc = start_tunnel(args.port, Path(args.cloudflared_path)) else: print("Tunnel disabled. Use the local URL from inside the Colab runtime.", flush=True) try: server.serve_forever() finally: server.server_close() if tunnel_proc and tunnel_proc.poll() is None: tunnel_proc.terminate() if __name__ == "__main__": main()