| """Minimal remote control plane for the SAGE FastAPI server.""" |
|
|
| from __future__ import annotations |
|
|
| import base64 |
| from collections.abc import Callable, Iterator |
| from dataclasses import dataclass, field |
| import hashlib |
| import hmac |
| import json |
| import os |
| from pathlib import Path |
| import secrets |
| import shlex |
| import shutil |
| import signal |
| import string |
| import subprocess |
| import sys |
| import threading |
| import time |
| from typing import Any |
| from uuid import uuid4 |
|
|
| from fastapi import APIRouter, Depends, HTTPException, Request, Response, status |
| from fastapi.responses import HTMLResponse, StreamingResponse |
| from pydantic import BaseModel, Field |
|
|
|
|
# Directory one level above the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# HTML page returned by the "/" route of the control router.
STATIC_INDEX = REPO_ROOT / "serve" / "static" / "index.html"
# Name of the cookie that carries the signed session token.
SESSION_COOKIE = "sage_session"
# Session lifetime (12 hours); used both as the cookie max-age and as the
# expiry window when a token is decoded.
SESSION_AGE_SECONDS = 60 * 60 * 12
# Length of the randomly generated password used when SAGE_WEB_PASSWORD is unset.
PASSWORD_LENGTH = 12
# Lazily initialised login password (see _get_password).
_RUNTIME_PASSWORD: str | None = None
# URLs recorded by set_runtime_access_urls and reported by get_runtime_access_info.
_RUNTIME_LOCAL_URL: str | None = None
_RUNTIME_PUBLIC_URL: str | None = None
|
|
|
|
@dataclass(frozen=True)
class PresetField:
    """Immutable description of a single input shown for a preset in the UI.

    Only ``name`` and ``label`` are mandatory; every other attribute has a
    default so simple text inputs can be declared with two arguments.
    """

    # Key under which the submitted value appears in the request ``args``.
    name: str
    # Human-readable caption; also used in "missing required fields" errors.
    label: str
    # Widget type hint for the front end (the presets in this file use
    # "text", "number", "boolean", "textarea" and "json").
    kind: str = "text"
    # Initial value offered to the user.
    default: Any = ""
    # Example text displayed while the input is empty.
    placeholder: str = ""
    # Whether the server rejects requests that leave this field blank.
    required: bool = False
|
|
|
|
@dataclass(frozen=True)
class CommandPreset:
    """Immutable description of one preset action offered by the browser UI."""

    # Stable key used by clients in ``preset_id``.
    identifier: str
    # Short title shown in the UI and used as the job label.
    label: str
    # Longer explanation of what the preset does.
    description: str
    # Execution mode declared for the preset ("api", "job" or "shell" in this file).
    mode: str
    # Input fields rendered for the preset, in display order.
    fields: tuple[PresetField, ...] = ()

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable view of the preset and its fields."""

        def describe(spec: PresetField) -> dict[str, Any]:
            # One plain dict per field, keys in declaration order.
            return {
                "name": spec.name,
                "label": spec.label,
                "kind": spec.kind,
                "default": spec.default,
                "placeholder": spec.placeholder,
                "required": spec.required,
            }

        serialized_fields = [describe(spec) for spec in self.fields]
        return {
            "id": self.identifier,
            "label": self.label,
            "description": self.description,
            "mode": self.mode,
            "fields": serialized_fields,
        }
|
|
|
|
class LoginRequest(BaseModel):
    """Request body accepted by the control UI login endpoint."""

    # Plain-text password submitted by the user.
    password: str
|
|
|
|
class RunCommandRequest(BaseModel):
    """Request body for running a preset action or a raw shell command.

    When ``preset_id`` is set it takes precedence; otherwise ``command`` is
    executed as a shell command, optionally inside ``cwd``.
    """

    # Identifier of the preset to run, if any.
    preset_id: str | None = None
    # Values for the preset's fields, keyed by field name.
    args: dict[str, Any] = Field(default_factory=dict)
    # Raw shell command used when no preset is selected.
    command: str | None = None
    # Working directory for a raw command.
    cwd: str | None = None
|
|
|
|
@dataclass
class CommandJob:
    """State, captured output and event history of one launched subprocess."""

    # Unique job id handed out to clients.
    identifier: str
    # Display name of the job.
    label: str
    # Execution mode recorded for the job (e.g. "job" or "shell").
    mode: str
    # Human-readable rendering of the command line.
    command: str
    # Working directory the process was started in.
    cwd: str
    # Lifecycle state; starts as "running".
    status: str = "running"
    # Process return code once it has exited.
    exit_code: int | None = None
    # Wall-clock start / end timestamps (seconds since the epoch).
    started_at: float = field(default_factory=time.time)
    ended_at: float | None = None
    # Set when a stop was requested for the job.
    stop_requested: bool = False
    # Handle of the underlying process, if any.
    process: subprocess.Popen[str] | None = None
    # Captured output lines without trailing newlines.
    logs: list[str] = field(default_factory=list)
    # Ordered event records of the form {"id", "event", "data"}.
    events: list[dict[str, Any]] = field(default_factory=list)
    # Id that the next emitted event will receive.
    next_event_id: int = 0
    # Guards the event list and wakes waiting consumers on new events.
    condition: threading.Condition = field(default_factory=threading.Condition)

    def emit(self, event: str, payload: dict[str, Any]) -> None:
        """Append an event record and notify every waiter on ``condition``."""
        with self.condition:
            record = {"id": self.next_event_id, "event": event, "data": payload}
            self.events.append(record)
            self.next_event_id += 1
            self.condition.notify_all()

    def append_log(self, line: str) -> None:
        """Store one output line (trailing newlines removed) and emit a "log" event."""
        stripped = line.rstrip("\n")
        self.logs.append(stripped)
        self.emit("log", {"line": stripped})

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable summary of the job (without log contents)."""
        summary: dict[str, Any] = {"id": self.identifier}
        copied = ("label", "mode", "command", "cwd", "status", "exit_code", "started_at", "ended_at")
        for attribute in copied:
            summary[attribute] = getattr(self, attribute)
        summary["log_lines"] = len(self.logs)
        return summary
|
|
|
|
| def _quote_shell(value: str) -> str: |
| if os.name == "nt": |
| return "'" + value.replace("'", "''") + "'" |
| return shlex.quote(value) |
|
|
|
|
| def _split_multi_value(value: Any) -> list[str]: |
| if value is None: |
| return [] |
| if isinstance(value, list): |
| return [str(item).strip() for item in value if str(item).strip()] |
| text = str(value).replace(",", "\n") |
| return [item.strip() for item in text.splitlines() if item.strip()] |
|
|
|
|
def _build_presets(enable_generate: bool) -> list[CommandPreset]:
    """Return the presets offered by the UI, in display order.

    The "generate" API preset is placed directly after the health check, and
    only when ``enable_generate`` is true.
    """
    health_check = CommandPreset(
        "health_check",
        "Health Check",
        "Call the local /health API and show the JSON response.",
        "api",
    )

    # Optional preset: present only when requested by the caller.
    optional_api: list[CommandPreset] = []
    if enable_generate:
        optional_api.append(
            CommandPreset(
                "generate",
                "Generate",
                "Call the local /generate API and show the token output.",
                "api",
                (
                    PresetField("input_ids", "Input IDs", kind="json", default=[1, 42, 99]),
                    PresetField("max_new_tokens", "Max New Tokens", kind="number", default=8),
                ),
            )
        )

    data_presets = [
        CommandPreset(
            "data_bootstrap",
            "Bootstrap Dataset",
            "Create small JSONL corpora under data/raw for tokenizer and smoke-training runs.",
            "job",
            (
                PresetField("output_dir", "Output Dir", default="data/raw"),
                PresetField("overwrite", "Overwrite Existing Files", kind="boolean", default=False),
            ),
        ),
        CommandPreset(
            "data_pipeline",
            "Build Data Shards",
            "Filter raw JSONL corpora, deduplicate them, then write parquet shards with the trained tokenizer.",
            "job",
            (
                PresetField("tokenizer_model", "Tokenizer Model", default="tokenizer/tokenizer.model"),
                PresetField("output_dir", "Output Dir", default="data/processed"),
                PresetField(
                    "sources",
                    "Sources",
                    kind="textarea",
                    placeholder="general_web\ncode\nmath_science\nmultilingual\nsynthetic",
                ),
                PresetField("shard_size", "Shard Size", kind="number", default=2048),
                PresetField("limit_per_source", "Limit Per Source", kind="number", default=0),
            ),
        ),
    ]

    serve_presets = [
        CommandPreset(
            "serve_gpu",
            "Serve GPU",
            "Start the GPU-oriented FastAPI server with uvicorn.",
            "job",
            (
                PresetField("host", "Host", default="0.0.0.0"),
                PresetField("port", "Port", kind="number", default=8000),
            ),
        ),
        CommandPreset(
            "serve_cpu",
            "Serve CPU",
            "Start the CPU-oriented FastAPI server with uvicorn.",
            "job",
            (
                PresetField("host", "Host", default="0.0.0.0"),
                PresetField("port", "Port", kind="number", default=8001),
            ),
        ),
    ]

    tokenizer_presets = [
        CommandPreset(
            "tokenizer_train",
            "Tokenizer Train",
            "Train the SentencePiece tokenizer from plain-text corpora.",
            "job",
            (
                PresetField(
                    "input_paths",
                    "Input Paths",
                    kind="textarea",
                    placeholder="data/raw/general_web.jsonl\ndata/raw/code.jsonl",
                    required=True,
                ),
                PresetField("model_prefix", "Model Prefix", default="tokenizer/tokenizer"),
                PresetField("vocab_size", "Vocab Size", kind="number", default=50000),
                PresetField("training_text", "Training Text", default="tokenizer/training_corpus.txt"),
            ),
        ),
        CommandPreset(
            "tokenizer_validate",
            "Tokenizer Validate",
            "Run the tokenizer smoke validation suite.",
            "job",
            (PresetField("model_path", "Model Path", default="tokenizer/tokenizer.model"),),
        ),
    ]

    training_presets = [
        CommandPreset(
            "training_run",
            "Training Run",
            "Launch the trainer with explicit shard and config paths.",
            "job",
            (
                PresetField("model_config", "Model Config", default="configs/model/1b.yaml"),
                PresetField("schedule_config", "Schedule Config", default="configs/train/schedule.yaml"),
                PresetField(
                    "train_shards",
                    "Train Shards",
                    kind="textarea",
                    placeholder="data/processed/shard-00000.parquet",
                    required=True,
                ),
                PresetField(
                    "validation_shards",
                    "Validation Shards",
                    kind="textarea",
                    placeholder="data/processed/shard-00001.parquet",
                ),
                PresetField("output_dir", "Output Dir", default="runs/default"),
                PresetField("steps", "Steps", kind="number", default=20),
                PresetField("disable_wandb", "Disable W&B", kind="boolean", default=True),
            ),
        ),
        CommandPreset(
            "eval_run",
            "Eval Run",
            "Run the registered eval benchmarks.",
            "job",
        ),
    ]

    repository_presets = [
        CommandPreset(
            "git_status",
            "Git Status",
            "Show the current repository status.",
            "job",
        ),
        CommandPreset(
            "git_commit_push",
            "Git Add Commit Push",
            "Add selected paths, create a commit, and push it to the remote branch.",
            "shell",
            (
                PresetField(
                    "paths",
                    "Paths",
                    kind="textarea",
                    placeholder="serve\nserve/static\ntests\ntest.ipynb\nREADME.md",
                    required=True,
                ),
                PresetField("commit_message", "Commit Message", placeholder="feat: add control UI", required=True),
                PresetField("remote", "Remote", default="origin"),
                PresetField("branch", "Branch", default="main"),
            ),
        ),
        CommandPreset(
            "hf_sync",
            "Hugging Face Sync",
            "Push the current folder contents to the configured Hugging Face repo.",
            "job",
        ),
    ]

    return [
        health_check,
        *optional_api,
        *data_presets,
        *serve_presets,
        *tokenizer_presets,
        *training_presets,
        *repository_presets,
    ]
|
|
|
|
class CommandManager:
    """Track subprocess commands and expose their logs.

    Jobs are kept in an in-memory registry keyed by job id. Each started job
    gets a daemon thread that drains the process output into the job and
    publishes the final status once the process has exited.
    """

    # Statuses of a job whose process may still be alive.
    _ACTIVE_STATUSES = frozenset({"running", "stopping"})

    def __init__(self) -> None:
        self._jobs: dict[str, CommandJob] = {}
        self._lock = threading.Lock()

    def list_jobs(self) -> list[dict[str, Any]]:
        """Return summaries of all known jobs, most recently started first."""
        with self._lock:
            jobs = sorted(self._jobs.values(), key=lambda item: item.started_at, reverse=True)
        return [job.to_dict() for job in jobs]

    def get_job(self, job_id: str) -> CommandJob:
        """Return the job registered under ``job_id``.

        Raises:
            KeyError: if no such job exists.
        """
        with self._lock:
            job = self._jobs.get(job_id)
        if job is None:
            raise KeyError(job_id)
        return job

    def reset_for_tests(self) -> None:
        """Terminate every live process and forget all jobs."""
        with self._lock:
            jobs = list(self._jobs.values())
        for job in jobs:
            process = job.process
            if process is not None and process.poll() is None:
                try:
                    self._terminate_process(process)
                except Exception:
                    # Deliberately best effort: cleanup must never fail a test run.
                    pass
        with self._lock:
            self._jobs.clear()

    def start_job(self, label: str, command: list[str] | str, cwd: str, mode: str) -> CommandJob:
        """Launch ``command`` and start tracking it.

        Args:
            label: Display name stored on the job.
            command: An argv list executed directly, or a string executed
                through the platform shell.
            cwd: Working directory; relative paths are resolved against the
                repository root and an empty value means the repository root.
            mode: Mode string recorded on the job.

        Returns:
            The newly registered job.

        Raises:
            OSError: propagated from ``subprocess.Popen`` (for example when
                the executable or the working directory does not exist).
        """
        cwd_path = self._resolve_cwd(cwd)
        shell = isinstance(command, str)
        popen_command = self._build_shell_command(command) if shell else list(command)
        rendered = self._render_command(command)
        process = subprocess.Popen(
            popen_command,
            cwd=str(cwd_path),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr into the captured log
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,  # line buffered so output is forwarded promptly
            **self._process_group_kwargs(),
        )
        job = CommandJob(identifier=str(uuid4()), label=label, mode=mode, command=rendered, cwd=str(cwd_path), process=process)
        job.emit("status", {"status": "running"})
        with self._lock:
            self._jobs[job.identifier] = job
        threading.Thread(target=self._read_output, args=(job,), daemon=True).start()
        return job

    def stop_job(self, job_id: str) -> CommandJob:
        """Request termination of a running job.

        A job whose process has already exited, or whose final status has
        already been recorded, is returned unchanged. Otherwise the job is
        marked "stopping", the process is asked to terminate, and a watchdog
        thread force-kills it if it is still alive after a grace period.

        Raises:
            KeyError: if no such job exists.
        """
        job = self.get_job(job_id)
        process = job.process
        # Decide and publish under the job condition so a final status written
        # concurrently by the reader thread is never overwritten by "stopping".
        with job.condition:
            if process is None or process.poll() is not None:
                return job
            if job.status not in self._ACTIVE_STATUSES:
                return job
            job.stop_requested = True
            job.status = "stopping"
            job.emit("status", {"status": "stopping"})
        self._terminate_process(process)
        threading.Thread(target=self._force_kill_if_needed, args=(job,), daemon=True).start()
        return job

    def _resolve_cwd(self, cwd: str) -> Path:
        """Resolve ``cwd`` to an absolute path, defaulting to the repository root."""
        if not cwd:
            return REPO_ROOT
        requested = Path(cwd)
        if not requested.is_absolute():
            requested = REPO_ROOT / requested
        return requested.resolve()

    def _build_shell_command(self, command: str) -> list[str]:
        """Wrap a shell command string in the argv of the platform shell."""
        if os.name == "nt":
            return ["powershell", "-Command", command]
        shell = "bash" if shutil.which("bash") else "sh"
        return [shell, "-lc", command]

    def _render_command(self, command: list[str] | str) -> str:
        """Return a printable, shell-style rendering of ``command``."""
        if isinstance(command, str):
            return command
        if os.name == "nt":
            return subprocess.list2cmdline(command)
        return shlex.join(command)

    def _process_group_kwargs(self) -> dict[str, Any]:
        """Return the ``Popen`` keyword arguments that isolate the child in its own group/session."""
        if os.name == "nt":
            return {"creationflags": subprocess.CREATE_NEW_PROCESS_GROUP}
        return {"start_new_session": True}

    def _terminate_process(self, process: subprocess.Popen[str]) -> None:
        """Ask the process (on POSIX: its whole process group) to terminate."""
        if os.name == "nt":
            process.terminate()
            return
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except ProcessLookupError:
            # The group vanished between the liveness check and the signal.
            pass

    def _kill_process(self, process: subprocess.Popen[str]) -> None:
        """Forcefully kill the process (on POSIX: its whole process group)."""
        if os.name == "nt":
            process.kill()
            return
        try:
            os.killpg(process.pid, signal.SIGKILL)
        except ProcessLookupError:
            # Already gone; nothing left to kill.
            pass

    def _force_kill_if_needed(self, job: CommandJob) -> None:
        """Kill the job's process if it has not exited within five seconds."""
        process = job.process
        if process is None:
            return
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            self._kill_process(process)

    def _read_output(self, job: CommandJob) -> None:
        """Drain the process output into the job, then record the final status.

        Runs on the job's reader thread until the output pipe reaches EOF and
        the process has exited.
        """
        process = job.process
        if process is None:
            return
        stream = process.stdout
        if stream is not None:
            # Close the pipe once drained: jobs stay registered indefinitely,
            # so an unclosed pipe would pin one file descriptor per job.
            with stream:
                for line in iter(stream.readline, ""):
                    job.append_log(line)
        return_code = process.wait()
        if job.stop_requested:
            final_status = "stopped"
        elif return_code == 0:
            final_status = "completed"
        else:
            final_status = "failed"
        # Publish the state change and its event atomically so that observers
        # waiting on the condition never see a finished status without the
        # matching event. emit() re-enters the condition, which relies on the
        # job's default re-entrant condition lock.
        with job.condition:
            job.exit_code = return_code
            job.ended_at = time.time()
            job.status = final_status
            job.emit("status", {"status": final_status, "exit_code": return_code})
|
|
|
|
# Process-wide job registry used by the routes built in build_control_router.
CONTROL_MANAGER = CommandManager()
|
|
|
|
def _get_password() -> str | None:
    """Return the login password, initialising it on first use.

    The value of ``SAGE_WEB_PASSWORD`` is used when it is set and non-empty;
    otherwise a random alphanumeric password of ``PASSWORD_LENGTH`` characters
    is generated. The result is cached in ``_RUNTIME_PASSWORD``.
    """
    global _RUNTIME_PASSWORD
    if _RUNTIME_PASSWORD is not None:
        return _RUNTIME_PASSWORD
    configured = os.environ.get("SAGE_WEB_PASSWORD")
    if configured:
        _RUNTIME_PASSWORD = configured
        return _RUNTIME_PASSWORD
    charset = string.ascii_letters + string.digits
    characters = [secrets.choice(charset) for _ in range(PASSWORD_LENGTH)]
    _RUNTIME_PASSWORD = "".join(characters)
    return _RUNTIME_PASSWORD
|
|
|
|
def get_runtime_access_info() -> dict[str, str | None]:
    """Return the current login password together with the recorded access URLs."""
    info: dict[str, str | None] = {"password": _get_password()}
    info["local_url"] = _RUNTIME_LOCAL_URL
    info["public_url"] = _RUNTIME_PUBLIC_URL
    return info
|
|
|
|
def set_runtime_access_urls(local_url: str | None = None, public_url: str | None = None) -> None:
    """Store the local and public URLs to be reported as runtime access info.

    Both values are replaced on every call; omitted arguments reset the
    corresponding URL to ``None``.
    """
    global _RUNTIME_LOCAL_URL, _RUNTIME_PUBLIC_URL
    _RUNTIME_LOCAL_URL, _RUNTIME_PUBLIC_URL = local_url, public_url
|
|
|
|
| def _get_signing_secret() -> str: |
| return os.environ.get("SAGE_WEB_SECRET") or _get_password() or "sage-control-plane" |
|
|
|
|
def _encode_cookie_payload(payload: dict[str, Any]) -> str:
    """Serialise ``payload`` into a signed ``<base64 body>.<hex digest>`` token.

    The payload is dumped as compact JSON with sorted keys; the HMAC-SHA256
    digest is computed over those JSON bytes using the signing secret.
    """
    serialized = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
    encoded_body = base64.urlsafe_b64encode(serialized).decode("ascii")
    signing_key = _get_signing_secret().encode("utf-8")
    signature = hmac.new(signing_key, serialized, hashlib.sha256).hexdigest()
    return ".".join((encoded_body, signature))
|
|
|
|
| def _decode_cookie_payload(token: str | None) -> dict[str, Any] | None: |
| if not token or "." not in token: |
| return None |
| body, signature = token.split(".", 1) |
| try: |
| raw = base64.urlsafe_b64decode(body.encode("ascii")) |
| except Exception: |
| return None |
| expected = hmac.new(_get_signing_secret().encode("utf-8"), raw, hashlib.sha256).hexdigest() |
| if not hmac.compare_digest(signature, expected): |
| return None |
| payload = json.loads(raw.decode("utf-8")) |
| issued_at = float(payload.get("iat", 0)) |
| if time.time() - issued_at > SESSION_AGE_SECONDS: |
| return None |
| return payload |
|
|
|
|
def _require_session(request: Request) -> dict[str, Any]:
    """Return the session payload from the request cookie.

    Raises:
        HTTPException: 401 when the session cookie is absent or does not
            decode to a valid payload.
    """
    session = _decode_cookie_payload(request.cookies.get(SESSION_COOKIE))
    if session is not None:
        return session
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required.")
|
|
|
|
| def _parse_number(value: Any, default: int) -> int: |
| if value in (None, ""): |
| return default |
| return int(value) |
|
|
|
|
| def _api_response(handler: Callable[[dict[str, Any]], dict[str, Any]], args: dict[str, Any]) -> dict[str, Any]: |
| return {"kind": "api", "result": handler(args)} |
|
|
|
|
def _validate_preset_args(preset: CommandPreset, args: dict[str, Any]) -> None:
    """Ensure every required field of ``preset`` has a non-blank value in ``args``.

    A value counts as blank when it is ``None`` (or absent), a string that is
    empty after stripping whitespace, or an empty list.

    Raises:
        HTTPException: 400 listing the labels of all missing fields.
    """

    def is_blank(candidate: Any) -> bool:
        if candidate is None:
            return True
        if isinstance(candidate, str):
            return not candidate.strip()
        if isinstance(candidate, list):
            return not candidate
        return False

    missing = [
        spec.label
        for spec in preset.fields
        if spec.required and is_blank(args.get(spec.name))
    ]
    if not missing:
        return
    raise HTTPException(status_code=400, detail=f"Missing required fields: {', '.join(missing)}")
|
|
|
|
| def _build_command_for_preset(preset_id: str, args: dict[str, Any]) -> list[str] | str: |
| if preset_id == "data_bootstrap": |
| command = [sys.executable, "-m", "data.bootstrap", "--output-dir", str(args.get("output_dir") or "data/raw")] |
| if bool(args.get("overwrite", False)): |
| command.append("--overwrite") |
| return command |
| if preset_id == "data_pipeline": |
| command = [ |
| sys.executable, |
| "-m", |
| "data.pipeline", |
| "--tokenizer-model", |
| str(args.get("tokenizer_model") or "tokenizer/tokenizer.model"), |
| "--output-dir", |
| str(args.get("output_dir") or "data/processed"), |
| "--shard-size", |
| str(_parse_number(args.get("shard_size"), 2048)), |
| ] |
| sources = _split_multi_value(args.get("sources")) |
| if sources: |
| command.extend(["--sources", *sources]) |
| limit_per_source = _parse_number(args.get("limit_per_source"), 0) |
| if limit_per_source > 0: |
| command.extend(["--limit-per-source", str(limit_per_source)]) |
| return command |
| if preset_id == "serve_gpu": |
| return [ |
| sys.executable, |
| "-m", |
| "uvicorn", |
| "serve.server:app", |
| "--host", |
| str(args.get("host") or "0.0.0.0"), |
| "--port", |
| str(_parse_number(args.get("port"), 8000)), |
| ] |
| if preset_id == "serve_cpu": |
| return [ |
| sys.executable, |
| "-m", |
| "uvicorn", |
| "serve.server_cpu:app", |
| "--host", |
| str(args.get("host") or "0.0.0.0"), |
| "--port", |
| str(_parse_number(args.get("port"), 8001)), |
| ] |
| if preset_id == "tokenizer_train": |
| input_paths = _split_multi_value(args.get("input_paths")) |
| if not input_paths: |
| raise HTTPException(status_code=400, detail="Tokenizer training requires at least one input path.") |
| command = [ |
| sys.executable, |
| "-m", |
| "tokenizer.train_tokenizer", |
| "--input", |
| *input_paths, |
| "--model-prefix", |
| str(args.get("model_prefix") or "tokenizer/tokenizer"), |
| "--vocab-size", |
| str(_parse_number(args.get("vocab_size"), 50000)), |
| "--training-text", |
| str(args.get("training_text") or "tokenizer/training_corpus.txt"), |
| ] |
| return command |
| if preset_id == "tokenizer_validate": |
| return [sys.executable, "-m", "tokenizer.validate_tokenizer", str(args.get("model_path") or "tokenizer/tokenizer.model")] |
| if preset_id == "training_run": |
| train_shards = _split_multi_value(args.get("train_shards")) |
| if not train_shards: |
| raise HTTPException(status_code=400, detail="Training run requires at least one training shard.") |
| command = [ |
| sys.executable, |
| "-m", |
| "train.trainer", |
| "--model-config", |
| str(args.get("model_config") or "configs/model/1b.yaml"), |
| "--schedule-config", |
| str(args.get("schedule_config") or "configs/train/schedule.yaml"), |
| "--train-shards", |
| *train_shards, |
| "--output-dir", |
| str(args.get("output_dir") or "runs/default"), |
| "--steps", |
| str(_parse_number(args.get("steps"), 20)), |
| ] |
| validation_shards = _split_multi_value(args.get("validation_shards")) |
| if validation_shards: |
| command.extend(["--validation-shards", *validation_shards]) |
| if bool(args.get("disable_wandb", True)): |
| command.append("--disable-wandb") |
| return command |
| if preset_id == "eval_run": |
| return [sys.executable, "-m", "eval.run_benchmarks"] |
| if preset_id == "git_status": |
| return ["git", "status", "--short", "--branch"] |
| if preset_id == "hf_sync": |
| return [sys.executable, "hf_push.py"] |
| if preset_id == "git_commit_push": |
| paths = _split_multi_value(args.get("paths")) |
| if not paths: |
| raise HTTPException(status_code=400, detail="Git add/commit/push requires explicit paths.") |
| message = str(args.get("commit_message") or "").strip() |
| if not message: |
| raise HTTPException(status_code=400, detail="Git add/commit/push requires a commit message.") |
| remote = str(args.get("remote") or "origin") |
| branch = str(args.get("branch") or "main") |
| add_paths = " ".join(_quote_shell(path) for path in paths) |
| if os.name == "nt": |
| return ( |
| f"git add -- {add_paths}; " |
| f"if ($LASTEXITCODE -ne 0) {{ exit $LASTEXITCODE }}; " |
| f"git commit -m {_quote_shell(message)}; " |
| f"if ($LASTEXITCODE -ne 0) {{ exit $LASTEXITCODE }}; " |
| f"git push {_quote_shell(remote)} {_quote_shell(branch)}; " |
| "exit $LASTEXITCODE" |
| ) |
| return f"git add -- {add_paths} && git commit -m {_quote_shell(message)} && git push {_quote_shell(remote)} {_quote_shell(branch)}" |
| raise HTTPException(status_code=404, detail=f"Unknown preset: {preset_id}") |
|
|
|
|
def build_control_router(api_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]]) -> APIRouter:
    """Create the shared HTML UI + command control router.

    Args:
        api_handlers: Callables keyed by preset id that serve presets whose
            mode is "api". The "generate" preset is only offered when a
            handler with that key is present.

    Returns:
        A router exposing the index page, login, preset listing, command
        execution, and job inspection / streaming / stop endpoints. Every
        endpoint except the index page and login requires a valid session.
    """
    router = APIRouter()
    presets = _build_presets(enable_generate="generate" in api_handlers)
    preset_map = {item.identifier: item for item in presets}
    # Statuses of a job that may still produce events.
    active_statuses = frozenset({"running", "stopping"})

    def launch(label: str, command: list[str] | str, cwd: str, mode: str) -> dict[str, Any]:
        # Start a tracked job and report process start-up failures as 400.
        try:
            job = CONTROL_MANAGER.start_job(label, command, cwd=cwd, mode=mode)
        except OSError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return {"kind": "job", "job": job.to_dict()}

    def lookup_job(job_id: str) -> CommandJob:
        # Fetch a job or translate the registry's KeyError into a 404.
        try:
            return CONTROL_MANAGER.get_job(job_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail="Job not found.") from exc

    @router.get("/", response_class=HTMLResponse)
    def index() -> str:
        return STATIC_INDEX.read_text(encoding="utf-8")

    @router.post("/api/login")
    def login(payload: LoginRequest, response: Response) -> dict[str, Any]:
        expected = _get_password() or ""
        # Compare as bytes: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, which would turn a wrong
        # (non-ASCII) password into a server error instead of a 401.
        supplied = payload.password.encode("utf-8")
        if not hmac.compare_digest(supplied, expected.encode("utf-8")):
            raise HTTPException(status_code=401, detail="Invalid password.")
        token = _encode_cookie_payload({"iat": time.time(), "nonce": secrets.token_hex(8)})
        response.set_cookie(
            SESSION_COOKIE,
            token,
            max_age=SESSION_AGE_SECONDS,
            httponly=True,
            samesite="lax",
        )
        return {"success": True}

    @router.get("/api/commands/presets")
    def list_presets(_: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
        return {"presets": [preset.to_dict() for preset in presets], "repo_root": str(REPO_ROOT)}

    @router.post("/api/commands/run")
    def run_command(payload: RunCommandRequest, _: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
        if payload.preset_id:
            preset = preset_map.get(payload.preset_id)
            if preset is None:
                raise HTTPException(status_code=404, detail=f"Unknown preset: {payload.preset_id}")
            _validate_preset_args(preset, payload.args)
            if preset.mode == "api":
                handler = api_handlers.get(payload.preset_id)
                if handler is None:
                    raise HTTPException(status_code=400, detail=f"Preset {payload.preset_id} is not available on this server.")
                return _api_response(handler, payload.args)
            try:
                command = _build_command_for_preset(payload.preset_id, payload.args)
            except (TypeError, ValueError) as exc:
                # Malformed argument values (e.g. a non-numeric port) are a
                # client error, not a server fault.
                raise HTTPException(status_code=400, detail=f"Invalid preset arguments: {exc}") from exc
            mode = "shell" if isinstance(command, str) else "job"
            return launch(preset.label, command, str(REPO_ROOT), mode)
        if payload.command:
            return launch("Raw Command", payload.command, payload.cwd or str(REPO_ROOT), "shell")
        raise HTTPException(status_code=400, detail="Provide either preset_id or command.")

    @router.get("/api/jobs")
    def list_jobs(_: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
        return {"jobs": CONTROL_MANAGER.list_jobs()}

    @router.get("/api/jobs/{job_id}")
    def get_job(job_id: str, _: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
        job = lookup_job(job_id)
        # Only the most recent lines are returned to keep the response small.
        return {"job": job.to_dict(), "logs": job.logs[-200:]}

    @router.get("/api/jobs/{job_id}/stream")
    def stream_job(job_id: str, _: dict[str, Any] = Depends(_require_session)) -> StreamingResponse:
        job = lookup_job(job_id)

        def event_stream() -> Iterator[str]:
            """Yield the job's events as server-sent events until it finishes."""
            cursor = 0
            while True:
                with job.condition:
                    if cursor >= len(job.events):
                        # Nothing new yet: sleep until an event is emitted or
                        # the heartbeat interval elapses.
                        job.condition.wait(timeout=15)
                    pending = job.events[cursor:]
                    finished = job.status not in active_statuses
                if not pending:
                    if finished:
                        # Finished and nothing arrived within the wait window.
                        return
                    yield ": keep-alive\n\n"
                    continue
                for item in pending:
                    cursor = item["id"] + 1
                    payload = json.dumps(item["data"])
                    yield f"id: {item['id']}\nevent: {item['event']}\ndata: {payload}\n\n"
                    # End the stream on the terminal status *event* rather than
                    # on job.status alone: the status attribute can change
                    # before the matching event is published, and stopping on
                    # it could drop the final event.
                    if item["event"] == "status" and item["data"].get("status") not in active_statuses:
                        return

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    @router.post("/api/jobs/{job_id}/stop")
    def stop_job(job_id: str, _: dict[str, Any] = Depends(_require_session)) -> dict[str, Any]:
        try:
            job = CONTROL_MANAGER.stop_job(job_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail="Job not found.") from exc
        return {"job": job.to_dict()}

    return router
|
|