meta-rl-dsa-solver / scripts /deploy_and_smoke_train.py
Dishaaa25's picture
Expose training timing summaries in Space status
6c58fca
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any
DEFAULT_REMOTE = "space"
DEFAULT_REMOTE_BRANCH = "main"
DEFAULT_POLL_INTERVAL_SECONDS = 10
DEFAULT_DEPLOY_TIMEOUT_SECONDS = 60 * 20
DEFAULT_TRAIN_TIMEOUT_SECONDS = 60 * 30
DEFAULT_REQUIRED_HEALTHY_CHECKS = 3
DEFAULT_MIN_DEPLOY_WAIT_SECONDS = 30
DEFAULT_COMMIT_MESSAGE = "Smoke-train deployment update"
class ScriptError(RuntimeError):
pass
def print_info(message: str) -> None:
print(f"[info] {message}", flush=True)
def print_warn(message: str) -> None:
print(f"[warn] {message}", flush=True)
def print_error(message: str) -> None:
print(f"[error] {message}", file=sys.stderr, flush=True)
def pretty_json(payload: dict[str, Any]) -> str:
return json.dumps(payload, indent=2, sort_keys=True)
def format_duration(seconds: float | None) -> str:
if seconds is None:
return "n/a"
total_seconds = max(float(seconds), 0.0)
hours = int(total_seconds // 3600)
minutes = int((total_seconds % 3600) // 60)
secs = round(total_seconds % 60, 1)
if hours > 0:
return f"{hours}h {minutes}m {secs}s"
if minutes > 0:
return f"{minutes}m {secs}s"
return f"{secs}s"
def api_url(base_url: str, path: str) -> str:
return f"{base_url.rstrip('/')}/{path.lstrip('/')}"
def http_json(
method: str,
url: str,
*,
payload: dict[str, Any] | None = None,
timeout_seconds: int = 60,
) -> dict[str, Any]:
data = None
headers = {"Accept": "application/json"}
if payload is not None:
data = json.dumps(payload).encode("utf-8")
headers["Content-Type"] = "application/json"
request = urllib.request.Request(url=url, data=data, headers=headers, method=method.upper())
try:
with urllib.request.urlopen(request, timeout=timeout_seconds) as response:
raw = response.read().decode("utf-8")
if not raw.strip():
return {}
return json.loads(raw)
except urllib.error.HTTPError as exc:
body = exc.read().decode("utf-8", errors="replace")
raise ScriptError(f"HTTP {exc.code} for {url}: {body}") from exc
except urllib.error.URLError as exc:
raise ScriptError(f"Request to {url} failed: {exc}") from exc
@dataclass(frozen=True)
class GitOptions:
repo_root: Path
remote: str
remote_branch: str
commit_message: str
skip_commit: bool
skip_push: bool
def run_command(command: list[str], cwd: Path) -> subprocess.CompletedProcess[str]:
return subprocess.run(
command,
cwd=str(cwd),
text=True,
capture_output=True,
check=False,
)
def ensure_success(result: subprocess.CompletedProcess[str], action: str) -> None:
if result.returncode == 0:
return
message = result.stderr.strip() or result.stdout.strip() or f"{action} failed with exit code {result.returncode}"
raise ScriptError(f"{action} failed: {message}")
def has_uncommitted_changes(repo_root: Path) -> bool:
result = run_command(["git", "status", "--porcelain"], repo_root)
ensure_success(result, "git status")
return bool(result.stdout.strip())
def current_head_sha(repo_root: Path) -> str:
result = run_command(["git", "rev-parse", "HEAD"], repo_root)
ensure_success(result, "git rev-parse HEAD")
return result.stdout.strip()
def remote_branch_sha(repo_root: Path, remote: str, remote_branch: str) -> str | None:
result = run_command(["git", "ls-remote", remote, f"refs/heads/{remote_branch}"], repo_root)
ensure_success(result, f"git ls-remote {remote} refs/heads/{remote_branch}")
line = result.stdout.strip()
if not line:
return None
return line.split()[0]
def commit_changes_if_needed(options: GitOptions) -> str:
repo_root = options.repo_root
head_before = current_head_sha(repo_root)
if not has_uncommitted_changes(repo_root):
print_info(f"Working tree is clean at {head_before}.")
return head_before
if options.skip_commit:
raise ScriptError(
"Working tree has uncommitted changes, but --skip-commit was set. "
"Commit the changes manually or remove --skip-commit."
)
print_info("Staging and committing local changes before deploy.")
ensure_success(run_command(["git", "add", "-A"], repo_root), "git add -A")
commit_result = run_command(["git", "commit", "-m", options.commit_message], repo_root)
ensure_success(commit_result, "git commit")
if commit_result.stdout.strip():
print(commit_result.stdout.strip(), flush=True)
head_after = current_head_sha(repo_root)
print_info(f"Created commit {head_after}.")
return head_after
def deployment_needed(options: GitOptions, local_head_sha: str) -> bool:
remote_sha = remote_branch_sha(options.repo_root, options.remote, options.remote_branch)
if remote_sha is None:
print_info(f"Remote branch {options.remote}/{options.remote_branch} does not exist yet.")
return True
if remote_sha == local_head_sha:
print_info(f"Remote {options.remote}/{options.remote_branch} already points to {local_head_sha}.")
return False
print_info(
f"Remote {options.remote}/{options.remote_branch} is at {remote_sha}; "
f"local HEAD is {local_head_sha}."
)
return True
def push_current_head(options: GitOptions) -> None:
print_info(f"Pushing HEAD to {options.remote}/{options.remote_branch}.")
result = run_command(["git", "push", options.remote, f"HEAD:{options.remote_branch}"], options.repo_root)
ensure_success(result, f"git push {options.remote} HEAD:{options.remote_branch}")
summary = result.stdout.strip() or result.stderr.strip()
if summary:
print(summary, flush=True)
def fetch_health(base_url: str) -> dict[str, Any]:
return http_json("GET", api_url(base_url, "/health"))
def wait_for_health(
base_url: str,
*,
timeout_seconds: int,
poll_interval_seconds: int,
required_healthy_checks: int,
min_deploy_wait_seconds: int,
) -> dict[str, Any]:
deadline = time.time() + timeout_seconds
push_started_at = time.time()
prior_payload: dict[str, Any] | None = None
transition_observed = False
healthy_streak = 0
while time.time() < deadline:
try:
payload = fetch_health(base_url)
except ScriptError as exc:
healthy_streak = 0
transition_observed = True
print_info(f"Waiting for service health: {exc}")
time.sleep(poll_interval_seconds)
continue
if prior_payload is not None and payload != prior_payload:
transition_observed = True
elapsed = time.time() - push_started_at
if elapsed < min_deploy_wait_seconds:
remaining = max(0, int(min_deploy_wait_seconds - elapsed))
print_info(f"Health endpoint reachable. Waiting {remaining}s for deployment stabilization.")
prior_payload = payload
time.sleep(poll_interval_seconds)
continue
if payload.get("status") == "healthy":
healthy_streak += 1
print_info(
f"Health check {healthy_streak}/{required_healthy_checks}: "
f"training={payload.get('training')} model_loaded={payload.get('model_loaded')}"
)
if healthy_streak >= required_healthy_checks:
if not transition_observed and prior_payload is not None:
print_warn("No payload transition was observed after deploy; continuing because health stabilized.")
return payload
else:
healthy_streak = 0
print_info(f"Health payload not ready yet: {pretty_json(payload)}")
prior_payload = payload
time.sleep(poll_interval_seconds)
raise ScriptError(f"Service did not report healthy within {timeout_seconds} seconds.")
def fetch_training_status(base_url: str) -> dict[str, Any]:
return http_json("GET", api_url(base_url, "/train/status"))
def summarize_training_status(payload: dict[str, Any]) -> str:
status = payload.get("status")
phase = payload.get("phase")
completed = payload.get("completed_steps")
total = payload.get("total_steps")
difficulty = payload.get("current_difficulty")
problem_family = payload.get("last_problem_family")
reward = payload.get("last_reward")
elapsed_minutes = payload.get("elapsed_minutes")
pieces = [
f"status={status}",
f"phase={phase}",
f"steps={completed}/{total}",
]
if difficulty:
pieces.append(f"difficulty={difficulty}")
if problem_family:
pieces.append(f"family={problem_family}")
if reward is not None:
pieces.append(f"reward={reward}")
if elapsed_minutes is not None:
pieces.append(f"elapsed={elapsed_minutes}m")
return " ".join(pieces)
def print_runtime_baseline(payload: dict[str, Any]) -> None:
timing = payload.get("timing_summary") or {}
wall_clock_seconds = timing.get("wall_clock_seconds")
if wall_clock_seconds is None:
wall_clock_seconds = payload.get("elapsed_seconds")
if wall_clock_seconds is None:
return
avg_seconds_per_step = timing.get("avg_seconds_per_step")
avg_seconds_per_episode = timing.get("avg_seconds_per_episode")
steps_per_hour = timing.get("steps_per_hour")
episodes_per_hour = timing.get("episodes_per_hour")
print_info(
"Smoke runtime baseline: "
f"wall_clock={format_duration(float(wall_clock_seconds))}, "
f"avg_step={avg_seconds_per_step}s, "
f"avg_episode={avg_seconds_per_episode}s"
)
if steps_per_hour is not None or episodes_per_hour is not None:
six_hour_steps = round(float(steps_per_hour) * 6) if steps_per_hour is not None else None
seven_hour_steps = round(float(steps_per_hour) * 7) if steps_per_hour is not None else None
six_hour_episodes = round(float(episodes_per_hour) * 6) if episodes_per_hour is not None else None
seven_hour_episodes = round(float(episodes_per_hour) * 7) if episodes_per_hour is not None else None
print_info(
"Window estimate at current smoke throughput: "
f"6h -> steps={six_hour_steps}, episodes={six_hour_episodes}; "
f"7h -> steps={seven_hour_steps}, episodes={seven_hour_episodes}"
)
def ensure_no_active_training(base_url: str) -> None:
payload = fetch_training_status(base_url)
if payload.get("status") == "running":
raise ScriptError(
"The remote training manager already reports an active run. "
f"Current status: {summarize_training_status(payload)}"
)
def start_training(base_url: str, train_payload: dict[str, Any]) -> dict[str, Any]:
print_info(f"Starting training with payload: {pretty_json(train_payload)}")
payload = http_json("POST", api_url(base_url, "/train"), payload=train_payload)
print_info(f"Training accepted: {pretty_json(payload)}")
return payload
def poll_training_status(
base_url: str,
*,
poll_interval_seconds: int,
timeout_seconds: int,
) -> dict[str, Any]:
deadline = time.time() + timeout_seconds
last_signature: tuple[Any, ...] | None = None
while time.time() < deadline:
payload = fetch_training_status(base_url)
signature = (
payload.get("status"),
payload.get("phase"),
payload.get("completed_steps"),
payload.get("total_steps"),
payload.get("last_problem_family"),
payload.get("last_reward"),
payload.get("error"),
)
if signature != last_signature:
print_info(summarize_training_status(payload))
last_signature = signature
status = payload.get("status")
if status == "failed":
print_error("Training failed. Final payload follows.")
print(pretty_json(payload), flush=True)
return payload
if status == "succeeded":
print_info("Training succeeded. Final payload follows.")
print(pretty_json(payload), flush=True)
print_runtime_baseline(payload)
return payload
time.sleep(poll_interval_seconds)
raise ScriptError(f"Training did not finish within {timeout_seconds} seconds.")
def build_train_payload(args: argparse.Namespace) -> dict[str, Any]:
payload: dict[str, Any] = {"preset": args.preset}
if args.train_payload_json:
extra_payload = json.loads(args.train_payload_json)
if not isinstance(extra_payload, dict):
raise ScriptError("--train-payload-json must decode to a JSON object.")
payload.update(extra_payload)
return payload
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Deploy the current repo to a Hugging Face Space and run a smoke training job.",
)
parser.add_argument(
"--base-url",
required=True,
help="Base URL of the running server, for example https://<space>.hf.space or http://localhost:7860",
)
parser.add_argument(
"--preset",
default="smoke",
help="Training preset to send to /train. Defaults to smoke.",
)
parser.add_argument(
"--train-payload-json",
default=None,
help="Optional JSON object merged into the /train request body.",
)
parser.add_argument(
"--poll-interval-seconds",
type=int,
default=DEFAULT_POLL_INTERVAL_SECONDS,
help="How often to poll /health and /train/status.",
)
parser.add_argument(
"--deploy-timeout-seconds",
type=int,
default=DEFAULT_DEPLOY_TIMEOUT_SECONDS,
help="Maximum time to wait for the service to become healthy after push.",
)
parser.add_argument(
"--train-timeout-seconds",
type=int,
default=DEFAULT_TRAIN_TIMEOUT_SECONDS,
help="Maximum time to wait for the smoke train run to finish.",
)
parser.add_argument(
"--required-healthy-checks",
type=int,
default=DEFAULT_REQUIRED_HEALTHY_CHECKS,
help="Number of consecutive healthy /health checks required before training starts.",
)
parser.add_argument(
"--min-deploy-wait-seconds",
type=int,
default=DEFAULT_MIN_DEPLOY_WAIT_SECONDS,
help="Minimum time to wait after push before treating health as stable.",
)
parser.add_argument(
"--remote",
default=DEFAULT_REMOTE,
help="Git remote to push to when deploying.",
)
parser.add_argument(
"--remote-branch",
default=DEFAULT_REMOTE_BRANCH,
help="Remote branch to push HEAD to when deploying.",
)
parser.add_argument(
"--commit-message",
default=DEFAULT_COMMIT_MESSAGE,
help="Commit message to use if local changes need to be committed before push.",
)
parser.add_argument(
"--skip-commit",
action="store_true",
help="Do not auto-commit local changes before deploy.",
)
parser.add_argument(
"--skip-push",
action="store_true",
help="Skip git deploy entirely and just hit the running server.",
)
parser.add_argument(
"--skip-health-check",
action="store_true",
help="Skip waiting on /health before training.",
)
parser.add_argument(
"--trigger-only",
action="store_true",
help="Start the smoke run and exit without polling to completion.",
)
parser.add_argument(
"--status-only",
action="store_true",
help="Do not start a new run; just print /train/status and optionally poll it.",
)
parser.add_argument(
"--follow-running",
action="store_true",
help="If /train/status already reports a running job, follow it instead of failing.",
)
return parser
def maybe_deploy(args: argparse.Namespace, repo_root: Path) -> None:
if args.skip_push:
print_info("Skipping git deploy because --skip-push was set.")
return
git_options = GitOptions(
repo_root=repo_root,
remote=args.remote,
remote_branch=args.remote_branch,
commit_message=args.commit_message,
skip_commit=args.skip_commit,
skip_push=args.skip_push,
)
local_head_sha = commit_changes_if_needed(git_options)
if not deployment_needed(git_options, local_head_sha):
print_info("Skipping push because the remote is already on the current local HEAD.")
return
push_current_head(git_options)
def main(argv: list[str] | None = None) -> int:
args = build_parser().parse_args(argv)
repo_root = Path(__file__).resolve().parents[1]
try:
if args.status_only:
payload = fetch_training_status(args.base_url)
print(pretty_json(payload), flush=True)
if args.follow_running and payload.get("status") == "running":
final_status = poll_training_status(
args.base_url,
poll_interval_seconds=args.poll_interval_seconds,
timeout_seconds=args.train_timeout_seconds,
)
return 0 if final_status.get("status") == "succeeded" else 1
return 0
maybe_deploy(args, repo_root)
if not args.skip_health_check:
wait_for_health(
args.base_url,
timeout_seconds=args.deploy_timeout_seconds,
poll_interval_seconds=args.poll_interval_seconds,
required_healthy_checks=args.required_healthy_checks,
min_deploy_wait_seconds=args.min_deploy_wait_seconds,
)
else:
print_info("Skipping health wait because --skip-health-check was set.")
current_status = fetch_training_status(args.base_url)
if current_status.get("status") == "running":
if args.follow_running:
print_warn(
"A training job is already running; following the existing run instead of starting a new one."
)
final_status = poll_training_status(
args.base_url,
poll_interval_seconds=args.poll_interval_seconds,
timeout_seconds=args.train_timeout_seconds,
)
return 0 if final_status.get("status") == "succeeded" else 1
ensure_no_active_training(args.base_url)
train_payload = build_train_payload(args)
start_training(args.base_url, train_payload)
if args.trigger_only:
print_info("Training was triggered successfully; exiting because --trigger-only was set.")
return 0
final_status = poll_training_status(
args.base_url,
poll_interval_seconds=args.poll_interval_seconds,
timeout_seconds=args.train_timeout_seconds,
)
return 0 if final_status.get("status") == "succeeded" else 1
except Exception as exc:
print_error(str(exc))
return 1
if __name__ == "__main__":
raise SystemExit(main())