Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import subprocess | |
| import sys | |
| import time | |
| import urllib.error | |
| import urllib.request | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
# Git target used when the matching CLI options are not supplied.
DEFAULT_REMOTE = "space"
DEFAULT_REMOTE_BRANCH = "main"
DEFAULT_COMMIT_MESSAGE = "Smoke-train deployment update"

# Polling cadence and upper bounds, all expressed in seconds.
DEFAULT_POLL_INTERVAL_SECONDS = 10
DEFAULT_DEPLOY_TIMEOUT_SECONDS = 20 * 60  # 20 minutes
DEFAULT_TRAIN_TIMEOUT_SECONDS = 30 * 60  # 30 minutes

# Health gating applied before a training run is started.
DEFAULT_REQUIRED_HEALTHY_CHECKS = 3
DEFAULT_MIN_DEPLOY_WAIT_SECONDS = 30
class ScriptError(RuntimeError):
    """Failure raised by this script's helpers with a human-readable message."""
def print_info(message: str) -> None:
    """Write *message* to stdout with an ``[info]`` tag and flush immediately."""
    print("[info]", message, flush=True)
def print_warn(message: str) -> None:
    """Write *message* to stdout with a ``[warn]`` tag and flush immediately."""
    print("[warn]", message, flush=True)
def print_error(message: str) -> None:
    """Write *message* to stderr with an ``[error]`` tag and flush immediately."""
    print("[error]", message, file=sys.stderr, flush=True)
def pretty_json(payload: dict[str, Any]) -> str:
    """Serialize *payload* as JSON with two-space indentation and sorted keys."""
    encoder = json.JSONEncoder(indent=2, sort_keys=True)
    return encoder.encode(payload)
def format_duration(seconds: float | None) -> str:
    """Render a duration in seconds as ``"[Hh ][Mm ]S.Ss"``.

    Args:
        seconds: Duration in seconds. ``None`` yields ``"n/a"``; negative
            values are clamped to zero.

    Returns:
        A string such as ``"5.0s"``, ``"2m 5.5s"`` or ``"1h 2m 5.5s"``. The
        hours component is shown only when non-zero; minutes are shown when
        non-zero or whenever hours are shown.
    """
    if seconds is None:
        return "n/a"

    # Round the total *before* splitting it so the seconds component can
    # never be displayed as 60.0 (e.g. 59.99 -> "1m 0.0s", not "60.0s").
    total_seconds = round(max(float(seconds), 0.0), 1)
    whole_hours, remainder = divmod(total_seconds, 3600)
    whole_minutes, leftover = divmod(remainder, 60)

    hours = int(whole_hours)
    minutes = int(whole_minutes)
    # Re-round only to strip float noise introduced by divmod (5.2999... -> 5.3).
    secs = round(leftover, 1)

    if hours > 0:
        return f"{hours}h {minutes}m {secs}s"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"
def api_url(base_url: str, path: str) -> str:
    """Join *base_url* and *path* with exactly one ``/`` between them.

    Trailing slashes on *base_url* and leading slashes on *path* are removed
    before joining.
    """
    root = base_url.rstrip("/")
    suffix = path.lstrip("/")
    return root + "/" + suffix
def http_json(
    method: str,
    url: str,
    *,
    payload: dict[str, Any] | None = None,
    timeout_seconds: int = 60,
) -> dict[str, Any]:
    """Send an HTTP request and return the JSON-object response body.

    Args:
        method: HTTP verb; upper-cased before sending.
        url: Absolute request URL.
        payload: Optional body, serialized as JSON and sent with a
            ``Content-Type: application/json`` header.
        timeout_seconds: Timeout passed to ``urllib.request.urlopen``.

    Returns:
        The decoded JSON object, or ``{}`` when the body is empty or
        whitespace-only.

    Raises:
        ScriptError: On an HTTP error status, on any transport-level failure
            (including timeouts and connection resets while reading), when
            the body is not valid UTF-8 JSON, or when the decoded JSON is not
            an object.
    """
    headers = {"Accept": "application/json"}
    data = None
    if payload is not None:
        data = json.dumps(payload).encode("utf-8")
        headers["Content-Type"] = "application/json"

    request = urllib.request.Request(url=url, data=data, headers=headers, method=method.upper())
    try:
        with urllib.request.urlopen(request, timeout=timeout_seconds) as response:
            raw_bytes = response.read()
    except urllib.error.HTTPError as exc:
        body = exc.read().decode("utf-8", errors="replace")
        raise ScriptError(f"HTTP {exc.code} for {url}: {body}") from exc
    except OSError as exc:
        # URLError is an OSError subclass; catching OSError additionally covers
        # socket timeouts and connection resets raised while reading the body,
        # so callers that retry on ScriptError keep retrying through them.
        raise ScriptError(f"Request to {url} failed: {exc}") from exc

    try:
        raw = raw_bytes.decode("utf-8")
        if not raw.strip():
            return {}
        parsed = json.loads(raw)
    except ValueError as exc:
        # Both UnicodeDecodeError and json.JSONDecodeError derive from ValueError.
        raise ScriptError(f"Response from {url} was not valid JSON: {exc}") from exc

    if not isinstance(parsed, dict):
        raise ScriptError(
            f"Response from {url} was JSON of type {type(parsed).__name__}; expected an object."
        )
    return parsed
@dataclass(frozen=True)
class GitOptions:
    """Immutable bundle of git settings used by the deploy helpers.

    The ``@dataclass`` decorator generates the keyword-accepting ``__init__``
    that ``maybe_deploy`` relies on; without it the class cannot be
    constructed with these fields.

    Attributes:
        repo_root: Directory in which git commands are executed.
        remote: Name of the git remote to push to.
        remote_branch: Branch on the remote that receives ``HEAD``.
        commit_message: Message used when local changes are auto-committed.
        skip_commit: When true, uncommitted changes are an error instead of
            being committed automatically.
        skip_push: Mirrors the ``--skip-push`` CLI flag.
    """

    repo_root: Path
    remote: str
    remote_branch: str
    commit_message: str
    skip_commit: bool
    skip_push: bool
def run_command(command: list[str], cwd: Path) -> subprocess.CompletedProcess[str]:
    """Execute *command* inside *cwd* and return the completed process.

    stdout and stderr are captured as text. A non-zero exit status does not
    raise; callers inspect ``returncode`` themselves.
    """
    working_dir = str(cwd)
    return subprocess.run(
        command,
        cwd=working_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=False,
    )
def ensure_success(result: subprocess.CompletedProcess[str], action: str) -> None:
    """Raise if *result* exited with a non-zero status.

    Args:
        result: Completed process to inspect.
        action: Human-readable label for the command, used as the error prefix.

    Raises:
        ScriptError: When ``result.returncode`` is non-zero. The detail is the
            process's stderr, else its stdout, else the exit code.
    """
    if result.returncode == 0:
        return
    # Streams are None when the process was run without capture; treat as empty.
    stderr_text = (result.stderr or "").strip()
    stdout_text = (result.stdout or "").strip()
    # The fallback carries only the exit code so the message does not repeat
    # the "<action> failed" prefix added below.
    detail = stderr_text or stdout_text or f"exit code {result.returncode}"
    raise ScriptError(f"{action} failed: {detail}")
def has_uncommitted_changes(repo_root: Path) -> bool:
    """Return True when ``git status --porcelain`` prints anything for *repo_root*.

    Raises:
        ScriptError: If the git command exits with a non-zero status.
    """
    status = run_command(["git", "status", "--porcelain"], repo_root)
    ensure_success(status, "git status")
    porcelain = status.stdout.strip()
    return porcelain != ""
def current_head_sha(repo_root: Path) -> str:
    """Return the stripped output of ``git rev-parse HEAD`` run in *repo_root*.

    Raises:
        ScriptError: If the git command exits with a non-zero status.
    """
    rev_parse = run_command(["git", "rev-parse", "HEAD"], repo_root)
    ensure_success(rev_parse, "git rev-parse HEAD")
    sha = rev_parse.stdout.strip()
    return sha
def remote_branch_sha(repo_root: Path, remote: str, remote_branch: str) -> str | None:
    """Look up the commit that ``refs/heads/<remote_branch>`` points to on *remote*.

    Returns:
        The first whitespace-separated field of the ``git ls-remote`` output,
        or ``None`` when the command prints nothing.

    Raises:
        ScriptError: If ``git ls-remote`` exits with a non-zero status.
    """
    ref = f"refs/heads/{remote_branch}"
    listing = run_command(["git", "ls-remote", remote, ref], repo_root)
    ensure_success(listing, f"git ls-remote {remote} {ref}")
    fields = listing.stdout.split()
    return fields[0] if fields else None
def commit_changes_if_needed(options: GitOptions) -> str:
    """Commit any local changes and return the resulting ``HEAD`` SHA.

    A clean working tree is left alone and the current ``HEAD`` is returned.
    Otherwise everything is staged with ``git add -A`` and committed with
    ``options.commit_message``.

    Raises:
        ScriptError: If the tree is dirty while ``options.skip_commit`` is
            set, or if any git command fails.
    """
    root = options.repo_root
    starting_sha = current_head_sha(root)

    dirty = has_uncommitted_changes(root)
    if not dirty:
        print_info(f"Working tree is clean at {starting_sha}.")
        return starting_sha

    if options.skip_commit:
        raise ScriptError(
            "Working tree has uncommitted changes, but --skip-commit was set. "
            "Commit the changes manually or remove --skip-commit."
        )

    print_info("Staging and committing local changes before deploy.")
    staged = run_command(["git", "add", "-A"], root)
    ensure_success(staged, "git add -A")

    committed = run_command(["git", "commit", "-m", options.commit_message], root)
    ensure_success(committed, "git commit")
    commit_output = committed.stdout.strip()
    if commit_output:
        print(commit_output, flush=True)

    new_sha = current_head_sha(root)
    print_info(f"Created commit {new_sha}.")
    return new_sha
def deployment_needed(options: GitOptions, local_head_sha: str) -> bool:
    """Decide whether a push is required by comparing the remote branch to *local_head_sha*.

    Returns:
        ``True`` when the remote branch is missing or points to a different
        commit; ``False`` when it already equals *local_head_sha*.
    """
    target = f"{options.remote}/{options.remote_branch}"
    remote_sha = remote_branch_sha(options.repo_root, options.remote, options.remote_branch)

    if remote_sha is None:
        print_info(f"Remote branch {target} does not exist yet.")
        return True

    up_to_date = remote_sha == local_head_sha
    if up_to_date:
        print_info(f"Remote {target} already points to {local_head_sha}.")
    else:
        print_info(f"Remote {target} is at {remote_sha}; local HEAD is {local_head_sha}.")
    return not up_to_date
def push_current_head(options: GitOptions) -> None:
    """Push local ``HEAD`` to ``options.remote_branch`` on ``options.remote``.

    Echoes git's stdout if non-empty, otherwise its stderr if non-empty.

    Raises:
        ScriptError: If ``git push`` exits with a non-zero status.
    """
    refspec = f"HEAD:{options.remote_branch}"
    print_info(f"Pushing HEAD to {options.remote}/{options.remote_branch}.")

    pushed = run_command(["git", "push", options.remote, refspec], options.repo_root)
    ensure_success(pushed, f"git push {options.remote} {refspec}")

    # Print the first non-empty stream only, preferring stdout.
    for stream in (pushed.stdout, pushed.stderr):
        text = stream.strip()
        if text:
            print(text, flush=True)
            break
def fetch_health(base_url: str) -> dict[str, Any]:
    """GET ``<base_url>/health`` and return the decoded JSON payload."""
    health_endpoint = api_url(base_url, "/health")
    return http_json("GET", health_endpoint)
def wait_for_health(
    base_url: str,
    *,
    timeout_seconds: int,
    poll_interval_seconds: int,
    required_healthy_checks: int,
    min_deploy_wait_seconds: int,
) -> dict[str, Any]:
    """Poll ``/health`` until the service has been healthy several times in a row.

    Args:
        base_url: Base URL of the service.
        timeout_seconds: Overall budget, measured from when this function is
            called.
        poll_interval_seconds: Sleep between polls.
        required_healthy_checks: Number of consecutive polls whose payload has
            ``status == "healthy"`` needed before returning.
        min_deploy_wait_seconds: Polls made before this much time has elapsed
            (also measured from the call) are not counted toward the streak.

    Returns:
        The health payload from the poll that completed the streak.

    Raises:
        ScriptError: If the streak is not reached within *timeout_seconds*.
    """
    # Use the monotonic clock so wall-clock adjustments (NTP, manual changes)
    # cannot stretch or shorten the timeout and the stabilization window.
    started_at = time.monotonic()
    deadline = started_at + timeout_seconds
    prior_payload: dict[str, Any] | None = None
    transition_observed = False
    healthy_streak = 0

    while time.monotonic() < deadline:
        try:
            payload = fetch_health(base_url)
        except ScriptError as exc:
            # An unreachable endpoint resets the streak and counts as a transition.
            healthy_streak = 0
            transition_observed = True
            print_info(f"Waiting for service health: {exc}")
            time.sleep(poll_interval_seconds)
            continue

        if prior_payload is not None and payload != prior_payload:
            transition_observed = True

        elapsed = time.monotonic() - started_at
        if elapsed < min_deploy_wait_seconds:
            # Still inside the stabilization window: remember the payload for
            # transition detection but do not count it toward the streak.
            remaining = max(0, int(min_deploy_wait_seconds - elapsed))
            print_info(f"Health endpoint reachable. Waiting {remaining}s for deployment stabilization.")
            prior_payload = payload
            time.sleep(poll_interval_seconds)
            continue

        if payload.get("status") == "healthy":
            healthy_streak += 1
            print_info(
                f"Health check {healthy_streak}/{required_healthy_checks}: "
                f"training={payload.get('training')} model_loaded={payload.get('model_loaded')}"
            )
            if healthy_streak >= required_healthy_checks:
                if not transition_observed and prior_payload is not None:
                    print_warn("No payload transition was observed after deploy; continuing because health stabilized.")
                return payload
        else:
            healthy_streak = 0
            print_info(f"Health payload not ready yet: {pretty_json(payload)}")

        prior_payload = payload
        time.sleep(poll_interval_seconds)

    raise ScriptError(f"Service did not report healthy within {timeout_seconds} seconds.")
def fetch_training_status(base_url: str) -> dict[str, Any]:
    """GET ``<base_url>/train/status`` and return the decoded JSON payload."""
    status_endpoint = api_url(base_url, "/train/status")
    return http_json("GET", status_endpoint)
def summarize_training_status(payload: dict[str, Any]) -> str:
    """Condense a ``/train/status`` payload into one space-separated line.

    ``status``, ``phase`` and ``steps`` are always present (missing keys
    render as ``None``). ``difficulty`` and ``family`` are added only when
    truthy; ``reward`` and ``elapsed`` are added whenever they are not
    ``None``.
    """
    parts = [
        f"status={payload.get('status')}",
        f"phase={payload.get('phase')}",
        f"steps={payload.get('completed_steps')}/{payload.get('total_steps')}",
    ]

    # Fields that are shown only when the value is truthy.
    for label, key in (("difficulty", "current_difficulty"), ("family", "last_problem_family")):
        value = payload.get(key)
        if value:
            parts.append(f"{label}={value}")

    last_reward = payload.get("last_reward")
    if last_reward is not None:
        parts.append(f"reward={last_reward}")

    minutes = payload.get("elapsed_minutes")
    if minutes is not None:
        parts.append(f"elapsed={minutes}m")

    return " ".join(parts)
def print_runtime_baseline(payload: dict[str, Any]) -> None:
    """Log timing figures from a finished run plus 6h/7h throughput projections.

    Wall-clock time is read from ``payload["timing_summary"]["wall_clock_seconds"]``
    and falls back to ``payload["elapsed_seconds"]``. Nothing is printed when
    neither is available. The projection line is printed only if at least one
    of ``steps_per_hour`` / ``episodes_per_hour`` is present.
    """
    timing = payload.get("timing_summary") or {}

    wall_clock = timing.get("wall_clock_seconds")
    if wall_clock is None:
        wall_clock = payload.get("elapsed_seconds")
        if wall_clock is None:
            return

    step_avg = timing.get("avg_seconds_per_step")
    episode_avg = timing.get("avg_seconds_per_episode")
    step_rate = timing.get("steps_per_hour")
    episode_rate = timing.get("episodes_per_hour")

    print_info(
        "Smoke runtime baseline: "
        f"wall_clock={format_duration(float(wall_clock))}, "
        f"avg_step={step_avg}s, "
        f"avg_episode={episode_avg}s"
    )

    if step_rate is None and episode_rate is None:
        return

    def project(rate: Any, hours: int) -> int | None:
        # Scale an hourly rate to a window of *hours*; None stays None.
        return None if rate is None else round(float(rate) * hours)

    steps_6h = project(step_rate, 6)
    steps_7h = project(step_rate, 7)
    episodes_6h = project(episode_rate, 6)
    episodes_7h = project(episode_rate, 7)
    print_info(
        "Window estimate at current smoke throughput: "
        f"6h -> steps={steps_6h}, episodes={episodes_6h}; "
        f"7h -> steps={steps_7h}, episodes={episodes_7h}"
    )
def ensure_no_active_training(base_url: str) -> None:
    """Fetch ``/train/status`` and refuse to continue if a run is in progress.

    Raises:
        ScriptError: When the reported ``status`` is ``"running"``.
    """
    status_payload = fetch_training_status(base_url)
    if status_payload.get("status") != "running":
        return
    raise ScriptError(
        "The remote training manager already reports an active run. "
        f"Current status: {summarize_training_status(status_payload)}"
    )
def start_training(base_url: str, train_payload: dict[str, Any]) -> dict[str, Any]:
    """POST *train_payload* to ``<base_url>/train`` and return the server's reply.

    Both the request body and the reply are logged as pretty-printed JSON.
    """
    request_body = pretty_json(train_payload)
    print_info(f"Starting training with payload: {request_body}")

    train_endpoint = api_url(base_url, "/train")
    reply = http_json("POST", train_endpoint, payload=train_payload)

    print_info(f"Training accepted: {pretty_json(reply)}")
    return reply
def poll_training_status(
    base_url: str,
    *,
    poll_interval_seconds: int,
    timeout_seconds: int,
) -> dict[str, Any]:
    """Poll ``/train/status`` until the run reports ``succeeded`` or ``failed``.

    A one-line summary is logged whenever any of the tracked fields changes.
    On a terminal status the full payload is printed; on success the runtime
    baseline is printed as well.

    Args:
        base_url: Base URL of the service.
        poll_interval_seconds: Sleep between polls.
        timeout_seconds: Overall budget, measured from when this function is
            called.

    Returns:
        The final status payload (``status`` is ``"succeeded"`` or ``"failed"``).

    Raises:
        ScriptError: If no terminal status is seen within *timeout_seconds*,
            or if a status request fails.
    """
    # Fields whose change triggers a new progress line.
    signature_keys = (
        "status",
        "phase",
        "completed_steps",
        "total_steps",
        "last_problem_family",
        "last_reward",
        "error",
    )

    # Monotonic clock: the timeout must not be affected by wall-clock jumps.
    deadline = time.monotonic() + timeout_seconds
    last_signature: tuple[Any, ...] | None = None

    while time.monotonic() < deadline:
        payload = fetch_training_status(base_url)

        signature = tuple(payload.get(key) for key in signature_keys)
        if signature != last_signature:
            print_info(summarize_training_status(payload))
            last_signature = signature

        status = payload.get("status")
        if status == "failed":
            print_error("Training failed. Final payload follows.")
            print(pretty_json(payload), flush=True)
            return payload
        if status == "succeeded":
            print_info("Training succeeded. Final payload follows.")
            print(pretty_json(payload), flush=True)
            print_runtime_baseline(payload)
            return payload

        time.sleep(poll_interval_seconds)

    raise ScriptError(f"Training did not finish within {timeout_seconds} seconds.")
def build_train_payload(args: argparse.Namespace) -> dict[str, Any]:
    """Build the ``/train`` request body from parsed CLI arguments.

    Starts from ``{"preset": args.preset}`` and merges in the object given via
    ``--train-payload-json``; keys from the JSON override the preset.

    Raises:
        ScriptError: If ``--train-payload-json`` is not valid JSON or does not
            decode to a JSON object.
    """
    payload: dict[str, Any] = {"preset": args.preset}

    raw_extra = args.train_payload_json
    if not raw_extra:
        return payload

    try:
        extra_payload = json.loads(raw_extra)
    except json.JSONDecodeError as exc:
        # Report malformed input as a script-level error that names the flag,
        # rather than leaking a bare decoder message.
        raise ScriptError(f"--train-payload-json is not valid JSON: {exc}") from exc

    if not isinstance(extra_payload, dict):
        raise ScriptError("--train-payload-json must decode to a JSON object.")

    payload.update(extra_payload)
    return payload
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the deploy-and-smoke-train script.

    Options are registered in a fixed order: the base URL and payload options,
    the integer timing options, the git string options, then the boolean
    switches.
    """
    parser = argparse.ArgumentParser(
        description="Deploy the current repo to a Hugging Face Space and run a smoke training job.",
    )
    add = parser.add_argument

    add(
        "--base-url",
        required=True,
        help="Base URL of the running server, for example https://<space>.hf.space or http://localhost:7860",
    )
    add("--preset", default="smoke", help="Training preset to send to /train. Defaults to smoke.")
    add(
        "--train-payload-json",
        default=None,
        help="Optional JSON object merged into the /train request body.",
    )

    # (flag, default, help) for options parsed as integers.
    integer_options = (
        (
            "--poll-interval-seconds",
            DEFAULT_POLL_INTERVAL_SECONDS,
            "How often to poll /health and /train/status.",
        ),
        (
            "--deploy-timeout-seconds",
            DEFAULT_DEPLOY_TIMEOUT_SECONDS,
            "Maximum time to wait for the service to become healthy after push.",
        ),
        (
            "--train-timeout-seconds",
            DEFAULT_TRAIN_TIMEOUT_SECONDS,
            "Maximum time to wait for the smoke train run to finish.",
        ),
        (
            "--required-healthy-checks",
            DEFAULT_REQUIRED_HEALTHY_CHECKS,
            "Number of consecutive healthy /health checks required before training starts.",
        ),
        (
            "--min-deploy-wait-seconds",
            DEFAULT_MIN_DEPLOY_WAIT_SECONDS,
            "Minimum time to wait after push before treating health as stable.",
        ),
    )
    for flag, default, help_text in integer_options:
        add(flag, type=int, default=default, help=help_text)

    # (flag, default, help) for plain string options.
    string_options = (
        ("--remote", DEFAULT_REMOTE, "Git remote to push to when deploying."),
        ("--remote-branch", DEFAULT_REMOTE_BRANCH, "Remote branch to push HEAD to when deploying."),
        (
            "--commit-message",
            DEFAULT_COMMIT_MESSAGE,
            "Commit message to use if local changes need to be committed before push.",
        ),
    )
    for flag, default, help_text in string_options:
        add(flag, default=default, help=help_text)

    # (flag, help) for boolean switches that default to False.
    switches = (
        ("--skip-commit", "Do not auto-commit local changes before deploy."),
        ("--skip-push", "Skip git deploy entirely and just hit the running server."),
        ("--skip-health-check", "Skip waiting on /health before training."),
        ("--trigger-only", "Start the smoke run and exit without polling to completion."),
        ("--status-only", "Do not start a new run; just print /train/status and optionally poll it."),
        (
            "--follow-running",
            "If /train/status already reports a running job, follow it instead of failing.",
        ),
    )
    for flag, help_text in switches:
        add(flag, action="store_true", help=help_text)

    return parser
def maybe_deploy(args: argparse.Namespace, repo_root: Path) -> None:
    """Commit and push the repository unless deployment is disabled or unnecessary.

    Does nothing when ``--skip-push`` is set. Otherwise local changes are
    committed if needed, and ``HEAD`` is pushed only when the remote branch
    does not already point to it.
    """
    if args.skip_push:
        print_info("Skipping git deploy because --skip-push was set.")
        return

    options = GitOptions(
        repo_root=repo_root,
        remote=args.remote,
        remote_branch=args.remote_branch,
        commit_message=args.commit_message,
        skip_commit=args.skip_commit,
        skip_push=args.skip_push,
    )

    head_sha = commit_changes_if_needed(options)
    if deployment_needed(options, head_sha):
        push_current_head(options)
        return
    print_info("Skipping push because the remote is already on the current local HEAD.")
def main(argv: list[str] | None = None) -> int:
    """Run the deploy / health-wait / smoke-train workflow.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        Process exit code: ``0`` when the requested action succeeded (for a
        followed run, when its final status is ``"succeeded"``), ``1``
        otherwise. Any exception is reported on stderr and mapped to ``1``.
    """
    args = build_parser().parse_args(argv)
    # The repository root is taken to be the parent of this file's directory.
    repo_root = Path(__file__).resolve().parents[1]

    def follow_to_completion() -> int:
        # Poll the current run to a terminal state and map it to an exit code.
        final_status = poll_training_status(
            args.base_url,
            poll_interval_seconds=args.poll_interval_seconds,
            timeout_seconds=args.train_timeout_seconds,
        )
        return 0 if final_status.get("status") == "succeeded" else 1

    try:
        if args.status_only:
            payload = fetch_training_status(args.base_url)
            print(pretty_json(payload), flush=True)
            if args.follow_running and payload.get("status") == "running":
                return follow_to_completion()
            return 0

        # Validate --train-payload-json before any side effects so a malformed
        # value is rejected up front instead of after a commit, push and
        # health wait.
        train_payload = build_train_payload(args)

        maybe_deploy(args, repo_root)

        if args.skip_health_check:
            print_info("Skipping health wait because --skip-health-check was set.")
        else:
            wait_for_health(
                args.base_url,
                timeout_seconds=args.deploy_timeout_seconds,
                poll_interval_seconds=args.poll_interval_seconds,
                required_healthy_checks=args.required_healthy_checks,
                min_deploy_wait_seconds=args.min_deploy_wait_seconds,
            )

        current_status = fetch_training_status(args.base_url)
        if current_status.get("status") == "running" and args.follow_running:
            print_warn(
                "A training job is already running; following the existing run instead of starting a new one."
            )
            return follow_to_completion()

        # Raises if a run is active and --follow-running was not given.
        ensure_no_active_training(args.base_url)

        start_training(args.base_url, train_payload)
        if args.trigger_only:
            print_info("Training was triggered successfully; exiting because --trigger-only was set.")
            return 0

        return follow_to_completion()
    except Exception as exc:  # top-level boundary: report and convert to exit code
        print_error(str(exc))
        return 1
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())