sage / serve /control_plane.py
sage002's picture
SAGE model repository : Updating some model checkpoints
773566b verified
"""Minimal remote control plane for the SAGE FastAPI server."""
from __future__ import annotations
import base64
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
import hashlib
import hmac
import json
import os
from pathlib import Path
import secrets
import shlex
import shutil
import signal
import string
import subprocess
import sys
import threading
import time
from typing import Any
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
REPO_ROOT = Path(__file__).resolve().parents[1]
STATIC_INDEX = REPO_ROOT / "serve" / "static" / "index.html"
SESSION_COOKIE = "sage_session"
SESSION_AGE_SECONDS = 60 * 60 * 12
PASSWORD_LENGTH = 12
_RUNTIME_PASSWORD: str | None = None
_RUNTIME_LOCAL_URL: str | None = None
_RUNTIME_PUBLIC_URL: str | None = None
@dataclass(frozen=True)
class PresetField:
"""One UI field for a preset action."""
name: str
label: str
kind: str = "text"
default: Any = ""
placeholder: str = ""
required: bool = False
@dataclass(frozen=True)
class CommandPreset:
"""One preset exposed in the browser UI."""
identifier: str
label: str
description: str
mode: str
fields: tuple[PresetField, ...] = ()
def to_dict(self) -> dict[str, Any]:
return {
"id": self.identifier,
"label": self.label,
"description": self.description,
"mode": self.mode,
"fields": [
{
"name": field.name,
"label": field.label,
"kind": field.kind,
"default": field.default,
"placeholder": field.placeholder,
"required": field.required,
}
for field in self.fields
],
}
class LoginRequest(BaseModel):
"""Login payload for the control UI."""
password: str
class RunCommandRequest(BaseModel):
"""Run either a preset action or a raw shell command."""
preset_id: str | None = None
args: dict[str, Any] = Field(default_factory=dict)
command: str | None = None
cwd: str | None = None
@dataclass
class CommandJob:
"""One tracked subprocess job."""
identifier: str
label: str
mode: str
command: str
cwd: str
status: str = "running"
exit_code: int | None = None
started_at: float = field(default_factory=time.time)
ended_at: float | None = None
stop_requested: bool = False
process: subprocess.Popen[str] | None = None
logs: list[str] = field(default_factory=list)
events: list[dict[str, Any]] = field(default_factory=list)
next_event_id: int = 0
condition: threading.Condition = field(default_factory=threading.Condition)
def emit(self, event: str, payload: dict[str, Any]) -> None:
with self.condition:
self.events.append({"id": self.next_event_id, "event": event, "data": payload})
self.next_event_id += 1
self.condition.notify_all()
def append_log(self, line: str) -> None:
clean = line.rstrip("\n")
self.logs.append(clean)
self.emit("log", {"line": clean})
def to_dict(self) -> dict[str, Any]:
return {
"id": self.identifier,
"label": self.label,
"mode": self.mode,
"command": self.command,
"cwd": self.cwd,
"status": self.status,
"exit_code": self.exit_code,
"started_at": self.started_at,
"ended_at": self.ended_at,
"log_lines": len(self.logs),
}
def _quote_shell(value: str) -> str:
if os.name == "nt":
return "'" + value.replace("'", "''") + "'"
return shlex.quote(value)
def _split_multi_value(value: Any) -> list[str]:
if value is None:
return []
if isinstance(value, list):
return [str(item).strip() for item in value if str(item).strip()]
text = str(value).replace(",", "\n")
return [item.strip() for item in text.splitlines() if item.strip()]
def _build_presets(enable_generate: bool) -> list[CommandPreset]:
presets = [
CommandPreset(
"health_check",
"Health Check",
"Call the local /health API and show the JSON response.",
"api",
),
CommandPreset(
"data_bootstrap",
"Bootstrap Dataset",
"Create small JSONL corpora under data/raw for tokenizer and smoke-training runs.",
"job",
(
PresetField("output_dir", "Output Dir", default="data/raw"),
PresetField("overwrite", "Overwrite Existing Files", kind="boolean", default=False),
),
),
CommandPreset(
"data_pipeline",
"Build Data Shards",
"Filter raw JSONL corpora, deduplicate them, then write parquet shards with the trained tokenizer.",
"job",
(
PresetField("tokenizer_model", "Tokenizer Model", default="tokenizer/tokenizer.model"),
PresetField("output_dir", "Output Dir", default="data/processed"),
PresetField(
"sources",
"Sources",
kind="textarea",
placeholder="general_web\ncode\nmath_science\nmultilingual\nsynthetic",
),
PresetField("shard_size", "Shard Size", kind="number", default=2048),
PresetField("limit_per_source", "Limit Per Source", kind="number", default=0),
),
),
CommandPreset(
"serve_gpu",
"Serve GPU",
"Start the GPU-oriented FastAPI server with uvicorn.",
"job",
(
PresetField("host", "Host", default="0.0.0.0"),
PresetField("port", "Port", kind="number", default=8000),
),
),
CommandPreset(
"serve_cpu",
"Serve CPU",
"Start the CPU-oriented FastAPI server with uvicorn.",
"job",
(
PresetField("host", "Host", default="0.0.0.0"),
PresetField("port", "Port", kind="number", default=8001),
),
),
CommandPreset(
"tokenizer_train",
"Tokenizer Train",
"Train the SentencePiece tokenizer from plain-text corpora.",
"job",
(
PresetField(
"input_paths",
"Input Paths",
kind="textarea",
placeholder="data/raw/general_web.jsonl\ndata/raw/code.jsonl",
required=True,
),
PresetField("model_prefix", "Model Prefix", default="tokenizer/tokenizer"),
PresetField("vocab_size", "Vocab Size", kind="number", default=50000),
PresetField("training_text", "Training Text", default="tokenizer/training_corpus.txt"),
),
),
CommandPreset(
"tokenizer_validate",
"Tokenizer Validate",
"Run the tokenizer smoke validation suite.",
"job",
(PresetField("model_path", "Model Path", default="tokenizer/tokenizer.model"),),
),
CommandPreset(
"training_run",
"Training Run",
"Launch the trainer with explicit shard and config paths.",
"job",
(
PresetField("model_config", "Model Config", default="configs/model/1b.yaml"),
PresetField("schedule_config", "Schedule Config", default="configs/train/schedule.yaml"),
PresetField(
"train_shards",
"Train Shards",
kind="textarea",
placeholder="data/processed/shard-00000.parquet",
required=True,
),
PresetField(
"validation_shards",
"Validation Shards",
kind="textarea",
placeholder="data/processed/shard-00001.parquet",
),
PresetField("output_dir", "Output Dir", default="runs/default"),
PresetField("steps", "Steps", kind="number", default=20),
PresetField("disable_wandb", "Disable W&B", kind="boolean", default=True),
),
),
CommandPreset(
"eval_run",
"Eval Run",
"Run the registered eval benchmarks.",
"job",
),
CommandPreset(
"git_status",
"Git Status",
"Show the current repository status.",
"job",
),
CommandPreset(
"git_commit_push",
"Git Add Commit Push",
"Add selected paths, create a commit, and push it to the remote branch.",
"shell",
(
PresetField(
"paths",
"Paths",
kind="textarea",
placeholder="serve\nserve/static\ntests\ntest.ipynb\nREADME.md",
required=True,
),
PresetField("commit_message", "Commit Message", placeholder="feat: add control UI", required=True),
PresetField("remote", "Remote", default="origin"),
PresetField("branch", "Branch", default="main"),
),
),
CommandPreset(
"hf_sync",
"Hugging Face Sync",
"Push the current folder contents to the configured Hugging Face repo.",
"job",
),
]
if enable_generate:
presets.insert(
1,
CommandPreset(
"generate",
"Generate",
"Call the local /generate API and show the token output.",
"api",
(
PresetField("input_ids", "Input IDs", kind="json", default=[1, 42, 99]),
PresetField("max_new_tokens", "Max New Tokens", kind="number", default=8),
),
),
)
return presets
class CommandManager:
"""Track subprocess commands and expose their logs."""
def __init__(self) -> None:
self._jobs: dict[str, CommandJob] = {}
self._lock = threading.Lock()
def list_jobs(self) -> list[dict[str, Any]]:
with self._lock:
jobs = sorted(self._jobs.values(), key=lambda item: item.started_at, reverse=True)
return [job.to_dict() for job in jobs]
def get_job(self, job_id: str) -> CommandJob:
with self._lock:
job = self._jobs.get(job_id)
if job is None:
raise KeyError(job_id)
return job
def reset_for_tests(self) -> None:
with self._lock:
jobs = list(self._jobs.values())
for job in jobs:
process = job.process
if process is not None and process.poll() is None:
try:
self._terminate_process(process)
except Exception:
pass
with self._lock:
self._jobs.clear()
def start_job(self, label: str, command: list[str] | str, cwd: str, mode: str) -> CommandJob:
cwd_path = self._resolve_cwd(cwd)
shell = isinstance(command, str)
popen_command = self._build_shell_command(command) if shell else list(command)
rendered = self._render_command(command)
process = subprocess.Popen(
popen_command,
cwd=str(cwd_path),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
errors="replace",
bufsize=1,
**self._process_group_kwargs(),
)
job = CommandJob(identifier=str(uuid4()), label=label, mode=mode, command=rendered, cwd=str(cwd_path), process=process)
job.emit("status", {"status": "running"})
with self._lock:
self._jobs[job.identifier] = job
threading.Thread(target=self._read_output, args=(job,), daemon=True).start()
return job
def stop_job(self, job_id: str) -> CommandJob:
job = self.get_job(job_id)
process = job.process
if process is None or process.poll() is not None:
return job
job.stop_requested = True
job.status = "stopping"
job.emit("status", {"status": "stopping"})
self._terminate_process(process)
threading.Thread(target=self._force_kill_if_needed, args=(job,), daemon=True).start()
return job
def _resolve_cwd(self, cwd: str) -> Path:
if not cwd:
return REPO_ROOT
requested = Path(cwd)
if not requested.is_absolute():
requested = REPO_ROOT / requested
return requested.resolve()
def _build_shell_command(self, command: str) -> list[str]:
if os.name == "nt":
return ["powershell", "-Command", command]
shell = "bash" if shutil.which("bash") else "sh"
return [shell, "-lc", command]
def _render_command(self, command: list[str] | str) -> str:
if isinstance(command, str):
return command
if os.name == "nt":
return subprocess.list2cmdline(command)
return shlex.join(command)
def _process_group_kwargs(self) -> dict[str, Any]:
if os.name == "nt":
return {"creationflags": subprocess.CREATE_NEW_PROCESS_GROUP}
return {"start_new_session": True}
def _terminate_process(self, process: subprocess.Popen[str]) -> None:
if os.name == "nt":
process.terminate()
return
os.killpg(process.pid, signal.SIGTERM)
def _kill_process(self, process: subprocess.Popen[str]) -> None:
if os.name == "nt":
process.kill()
return
os.killpg(process.pid, signal.SIGKILL)
def _force_kill_if_needed(self, job: CommandJob) -> None:
process = job.process
if process is None:
return
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
self._kill_process(process)
def _read_output(self, job: CommandJob) -> None:
process = job.process
if process is None:
return
stream = process.stdout
if stream is not None:
for line in iter(stream.readline, ""):
if line == "":
break
job.append_log(line)
return_code = process.wait()
job.exit_code = return_code
job.ended_at = time.time()
if job.stop_requested:
job.status = "stopped"
elif return_code == 0:
job.status = "completed"
else:
job.status = "failed"
job.emit("status", {"status": job.status, "exit_code": return_code})
CONTROL_MANAGER = CommandManager()
def _get_password() -> str | None:
global _RUNTIME_PASSWORD
if _RUNTIME_PASSWORD is None:
env_password = os.environ.get("SAGE_WEB_PASSWORD")
if env_password:
_RUNTIME_PASSWORD = env_password
else:
alphabet = string.ascii_letters + string.digits
_RUNTIME_PASSWORD = "".join(secrets.choice(alphabet) for _ in range(PASSWORD_LENGTH))
return _RUNTIME_PASSWORD
def get_runtime_access_info() -> dict[str, str | None]:
"""Return the current runtime login password and access URLs."""
return {
"password": _get_password(),
"local_url": _RUNTIME_LOCAL_URL,
"public_url": _RUNTIME_PUBLIC_URL,
}
def set_runtime_access_urls(local_url: str | None = None, public_url: str | None = None) -> None:
"""Record the URLs that should be shown in the startup banner."""
global _RUNTIME_LOCAL_URL, _RUNTIME_PUBLIC_URL
_RUNTIME_LOCAL_URL = local_url
_RUNTIME_PUBLIC_URL = public_url
def _get_signing_secret() -> str:
return os.environ.get("SAGE_WEB_SECRET") or _get_password() or "sage-control-plane"
def _encode_cookie_payload(payload: dict[str, Any]) -> str:
raw = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8")
body = base64.urlsafe_b64encode(raw).decode("ascii")
digest = hmac.new(_get_signing_secret().encode("utf-8"), raw, hashlib.sha256).hexdigest()
return f"{body}.{digest}"
def _decode_cookie_payload(token: str | None) -> dict[str, Any] | None:
if not token or "." not in token:
return None
body, signature = token.split(".", 1)
try:
raw = base64.urlsafe_b64decode(body.encode("ascii"))
except Exception:
return None
expected = hmac.new(_get_signing_secret().encode("utf-8"), raw, hashlib.sha256).hexdigest()
if not hmac.compare_digest(signature, expected):
return None
payload = json.loads(raw.decode("utf-8"))
issued_at = float(payload.get("iat", 0))
if time.time() - issued_at > SESSION_AGE_SECONDS:
return None
return payload
def _require_session(request: Request) -> dict[str, Any]:
payload = _decode_cookie_payload(request.cookies.get(SESSION_COOKIE))
if payload is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required.")
return payload
def _parse_number(value: Any, default: int) -> int:
if value in (None, ""):
return default
return int(value)
def _api_response(handler: Callable[[dict[str, Any]], dict[str, Any]], args: dict[str, Any]) -> dict[str, Any]:
return {"kind": "api", "result": handler(args)}
def _validate_preset_args(preset: CommandPreset, args: dict[str, Any]) -> None:
missing: list[str] = []
for field in preset.fields:
if not field.required:
continue
value = args.get(field.name)
if value is None:
missing.append(field.label)
continue
if isinstance(value, str) and not value.strip():
missing.append(field.label)
continue
if isinstance(value, list) and not value:
missing.append(field.label)
if missing:
raise HTTPException(status_code=400, detail=f"Missing required fields: {', '.join(missing)}")
def _build_command_for_preset(preset_id: str, args: dict[str, Any]) -> list[str] | str:
if preset_id == "data_bootstrap":
command = [sys.executable, "-m", "data.bootstrap", "--output-dir", str(args.get("output_dir") or "data/raw")]
if bool(args.get("overwrite", False)):
command.append("--overwrite")
return command
if preset_id == "data_pipeline":
command = [
sys.executable,
"-m",
"data.pipeline",
"--tokenizer-model",
str(args.get("tokenizer_model") or "tokenizer/tokenizer.model"),
"--output-dir",
str(args.get("output_dir") or "data/processed"),
"--shard-size",
str(_parse_number(args.get("shard_size"), 2048)),
]
sources = _split_multi_value(args.get("sources"))
if sources:
command.extend(["--sources", *sources])
limit_per_source = _parse_number(args.get("limit_per_source"), 0)
if limit_per_source > 0:
command.extend(["--limit-per-source", str(limit_per_source)])
return command
if preset_id == "serve_gpu":
return [
sys.executable,
"-m",
"uvicorn",
"serve.server:app",
"--host",
str(args.get("host") or "0.0.0.0"),
"--port",
str(_parse_number(args.get("port"), 8000)),
]
if preset_id == "serve_cpu":
return [
sys.executable,
"-m",
"uvicorn",
"serve.server_cpu:app",
"--host",
str(args.get("host") or "0.0.0.0"),
"--port",
str(_parse_number(args.get("port"), 8001)),
]
if preset_id == "tokenizer_train":
input_paths = _split_multi_value(args.get("input_paths"))
if not input_paths:
raise HTTPException(status_code=400, detail="Tokenizer training requires at least one input path.")
command = [
sys.executable,
"-m",
"tokenizer.train_tokenizer",
"--input",
*input_paths,
"--model-prefix",
str(args.get("model_prefix") or "tokenizer/tokenizer"),
"--vocab-size",
str(_parse_number(args.get("vocab_size"), 50000)),
"--training-text",
str(args.get("training_text") or "tokenizer/training_corpus.txt"),
]
return command
if preset_id == "tokenizer_validate":
return [sys.executable, "-m", "tokenizer.validate_tokenizer", str(args.get("model_path") or "tokenizer/tokenizer.model")]
if preset_id == "training_run":
train_shards = _split_multi_value(args.get("train_shards"))
if not train_shards:
raise HTTPException(status_code=400, detail="Training run requires at least one training shard.")
command = [
sys.executable,
"-m",
"train.trainer",
"--model-config",
str(args.get("model_config") or "configs/model/1b.yaml"),
"--schedule-config",
str(args.get("schedule_config") or "configs/train/schedule.yaml"),
"--train-shards",
*train_shards,
"--output-dir",
str(args.get("output_dir") or "runs/default"),
"--steps",
str(_parse_number(args.get("steps"), 20)),
]
validation_shards = _split_multi_value(args.get("validation_shards"))
if validation_shards:
command.extend(["--validation-shards", *validation_shards])
if bool(args.get("disable_wandb", True)):
command.append("--disable-wandb")
return command
if preset_id == "eval_run":
return [sys.executable, "-m", "eval.run_benchmarks"]
if preset_id == "git_status":
return ["git", "status", "--short", "--branch"]
if preset_id == "hf_sync":
return [sys.executable, "hf_push.py"]
if preset_id == "git_commit_push":
paths = _split_multi_value(args.get("paths"))
if not paths:
raise HTTPException(status_code=400, detail="Git add/commit/push requires explicit paths.")
message = str(args.get("commit_message") or "").strip()
if not message:
raise HTTPException(status_code=400, detail="Git add/commit/push requires a commit message.")
remote = str(args.get("remote") or "origin")
branch = str(args.get("branch") or "main")
add_paths = " ".join(_quote_shell(path) for path in paths)
if os.name == "nt":
return (
f"git add -- {add_paths}; "
f"if ($LASTEXITCODE -ne 0) {{ exit $LASTEXITCODE }}; "
f"git commit -m {_quote_shell(message)}; "
f"if ($LASTEXITCODE -ne 0) {{ exit $LASTEXITCODE }}; "
f"git push {_quote_shell(remote)} {_quote_shell(branch)}; "
"exit $LASTEXITCODE"
)
return f"git add -- {add_paths} && git commit -m {_quote_shell(message)} && git push {_quote_shell(remote)} {_quote_shell(branch)}"
raise HTTPException(status_code=404, detail=f"Unknown preset: {preset_id}")
def build_control_router(api_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]]) -> APIRouter:
"""Create the shared HTML UI + command control router."""
router = APIRouter()
presets = _build_presets(enable_generate="generate" in api_handlers)
preset_map = {item.identifier: item for item in presets}
@router.get("/", response_class=HTMLResponse)
def index() -> str:
return STATIC_INDEX.read_text(encoding="utf-8")
@router.post("/api/login")
def login(payload: LoginRequest, response: Response) -> dict[str, Any]:
password = _get_password()
if not hmac.compare_digest(payload.password, password):
raise HTTPException(status_code=401, detail="Invalid password.")
token = _encode_cookie_payload({"iat": time.time(), "nonce": secrets.token_hex(8)})
response.set_cookie(
SESSION_COOKIE,
token,
max_age=SESSION_AGE_SECONDS,
httponly=True,
samesite="lax",
)
return {"success": True}
@router.get("/api/commands/presets")
def list_presets(_: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
return {"presets": [preset.to_dict() for preset in presets], "repo_root": str(REPO_ROOT)}
@router.post("/api/commands/run")
def run_command(payload: RunCommandRequest, _: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
if payload.preset_id:
preset = preset_map.get(payload.preset_id)
if preset is None:
raise HTTPException(status_code=404, detail=f"Unknown preset: {payload.preset_id}")
_validate_preset_args(preset, payload.args)
if preset.mode == "api":
handler = api_handlers.get(payload.preset_id)
if handler is None:
raise HTTPException(status_code=400, detail=f"Preset {payload.preset_id} is not available on this server.")
return _api_response(handler, payload.args)
command = _build_command_for_preset(payload.preset_id, payload.args)
mode = "shell" if isinstance(command, str) else "job"
try:
job = CONTROL_MANAGER.start_job(preset.label, command, cwd=str(REPO_ROOT), mode=mode)
except OSError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {"kind": "job", "job": job.to_dict()}
if payload.command:
cwd = payload.cwd or str(REPO_ROOT)
try:
job = CONTROL_MANAGER.start_job("Raw Command", payload.command, cwd=cwd, mode="shell")
except OSError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {"kind": "job", "job": job.to_dict()}
raise HTTPException(status_code=400, detail="Provide either preset_id or command.")
@router.get("/api/jobs")
def list_jobs(_: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
return {"jobs": CONTROL_MANAGER.list_jobs()}
@router.get("/api/jobs/{job_id}")
def get_job(job_id: str, _: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
try:
job = CONTROL_MANAGER.get_job(job_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail="Job not found.") from exc
return {"job": job.to_dict(), "logs": job.logs[-200:]}
@router.get("/api/jobs/{job_id}/stream")
def stream_job(job_id: str, _: dict[str, Any] = Depends(_require_session)) -> StreamingResponse:
try:
job = CONTROL_MANAGER.get_job(job_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail="Job not found.") from exc
def event_stream() -> Iterator[str]:
index = 0
while True:
heartbeat = False
with job.condition:
while index >= len(job.events) and job.status in {"running", "stopping"}:
job.condition.wait(timeout=15)
if index >= len(job.events) and job.status in {"running", "stopping"}:
heartbeat = True
break
pending = job.events[index:]
if heartbeat:
yield ": keep-alive\n\n"
continue
for item in pending:
index = item["id"] + 1
payload = json.dumps(item["data"])
yield f"id: {item['id']}\nevent: {item['event']}\ndata: {payload}\n\n"
if job.status not in {"running", "stopping"} and index >= len(job.events):
break
return StreamingResponse(event_stream(), media_type="text/event-stream")
@router.post("/api/jobs/{job_id}/stop")
def stop_job(job_id: str, _: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
try:
job = CONTROL_MANAGER.stop_job(job_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail="Job not found.") from exc
return {"job": job.to_dict()}
return router