jw-search / scripts /runpod-batch-worker.py
jw-tools's picture
deploy: latest main (lazy-ML cold start, durable launcher, web-image search, scene search) + full-app data refresh
7ea1851 verified
#!/usr/bin/env python3
"""Orchestrate the full headless Search-UI batch pipeline on a RunPod Pod."""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
import time
from typing import Any
# Resolve repository locations relative to this file, which lives one level
# below the repo root (in scripts/).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
BACKEND_DIR = os.path.join(REPO_ROOT, "backend")
# Put the backend directory first on the import path so the project modules
# imported below can be found without installing them. This must run before
# those imports; presumably runtime_paths/utils live in BACKEND_DIR (verify).
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)
from runtime_paths import ensure_runtime_dirs  # noqa: E402 (needs sys.path setup above)
from utils import atomic_write_json  # noqa: E402 (needs sys.path setup above)
def _env_flag(name: str) -> bool:
    """Return whether environment variable *name* holds a truthy token.

    The value is stripped and lower-cased; only ``"1"``, ``"true"``, ``"yes"``
    and ``"on"`` count as true. An unset variable is false.
    """
    return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}


def parse_args() -> argparse.Namespace:
    """Parse CLI options, taking defaults from ``RUNPOD_*`` environment variables.

    Numeric defaults are handed to argparse as *strings*: argparse runs string
    defaults through ``type=int`` exactly like command-line values, so a
    malformed environment value is reported as a normal usage error (exit
    status 2) instead of an uncaught ``ValueError`` traceback at startup.
    """
    parser = argparse.ArgumentParser(description="Run the Search-UI batch worker on a RunPod Pod.")
    parser.add_argument(
        "--language",
        default=os.environ.get("RUNPOD_WORKER_LANGUAGE", "E"),
        help="Language code passed to the catalog-subtitle, video-download and rebuild scripts.",
    )
    parser.add_argument(
        "--label",
        default=os.environ.get("RUNPOD_WORKER_LABEL", "720p"),
        help="Label passed to the video download and processing scripts (e.g. 720p).",
    )
    parser.add_argument(
        "--report-dir",
        default=os.environ.get(
            "RUNPOD_WORKER_REPORT_DIR",
            os.path.join(REPO_ROOT, "docs", "reports", "runpod"),
        ),
        help="Directory receiving the run manifest, per-phase logs and reports.",
    )
    parser.add_argument(
        "--download-concurrency",
        type=int,
        default=os.environ.get("RUNPOD_DOWNLOAD_CONCURRENCY", "4"),
        help="Value forwarded as --concurrency to the video download script.",
    )
    parser.add_argument(
        "--process-shards",
        type=int,
        default=os.environ.get("RUNPOD_PROCESS_SHARDS", "4"),
        help="Number of parallel video-processing shards.",
    )
    parser.add_argument(
        "--process-limit",
        type=int,
        default=os.environ.get("RUNPOD_PROCESS_LIMIT", "0"),
        help="Forwarded as --limit to the download and processing scripts; 0 means no limit.",
    )
    parser.add_argument(
        "--partition-total",
        type=int,
        default=os.environ.get("RUNPOD_PARTITION_TOTAL", "1"),
        help="Split the backlog into N deterministic partitions (default: 1).",
    )
    parser.add_argument(
        "--partition-index",
        type=int,
        default=os.environ.get("RUNPOD_PARTITION_INDEX", "0"),
        help="0-based partition index to process when partition-total > 1.",
    )
    parser.add_argument(
        "--command-timeout-seconds",
        type=int,
        default=os.environ.get("RUNPOD_COMMAND_TIMEOUT_SECONDS", "64800"),
        help="Maximum runtime for a single non-sharded phase (0 disables timeout).",
    )
    parser.add_argument(
        "--shard-timeout-seconds",
        type=int,
        default=os.environ.get("RUNPOD_SHARD_TIMEOUT_SECONDS", "64800"),
        help="Maximum runtime for the full shard-processing phase (0 disables timeout).",
    )
    parser.add_argument(
        "--skip-catalog-subtitles",
        action="store_true",
        help="Skip standalone subtitle download and indexing.",
    )
    parser.add_argument(
        "--skip-video-download",
        action="store_true",
        help="Skip the video download phase.",
    )
    # For env-backed store_true flags, a truthy env value makes the default True;
    # the command line can then only confirm it, not switch it off.
    parser.add_argument(
        "--skip-embedded-subtitles",
        action="store_true",
        default=_env_flag("RUNPOD_SKIP_EMBEDDED_SUBTITLES"),
        help="Skip embedded subtitle extraction for downloaded MP4s.",
    )
    parser.add_argument(
        "--skip-video-processing",
        action="store_true",
        help="Skip thumbnail/image/face processing.",
    )
    parser.add_argument(
        "--rebuild-current-subtitle-embeddings",
        action="store_true",
        help="Run the current-recipe subtitle embedding rebuild after extraction.",
    )
    parser.add_argument(
        "--rebuild-video-concepts",
        action="store_true",
        help="Run the current-recipe video-concept rebuild after extraction.",
    )
    parser.add_argument(
        "--delete-video-after-processing",
        action="store_true",
        default=_env_flag("RUNPOD_DELETE_VIDEO_AFTER_PROCESSING"),
        help="Delete MP4s after image/face processing succeeds.",
    )
    parser.add_argument(
        "--redownload-metadata",
        action="store_true",
        help="Force catalog metadata refresh before download phases.",
    )
    return parser.parse_args()
def _save_manifest(path: str, payload: dict[str, Any]) -> None:
    """Persist the run manifest *payload* to *path*.

    Thin wrapper over the backend's ``atomic_write_json`` helper, always using
    ``indent=2``; the orchestrator calls it after every phase state change.
    """
    atomic_write_json(path, payload, indent=2)
def _parse_gpu_ids(value: str, *, probe_timeout_seconds: float = 30.0) -> list[int]:
    """Resolve the GPU ids that video-processing shards should be pinned to.

    Resolution order:

    1. *value* (ids separated by commas and/or whitespace) when it is non-blank.
    2. The purely numeric entries of ``CUDA_VISIBLE_DEVICES`` when that is set
       and not the ``-1`` sentinel; non-numeric entries are ignored.
    3. ``nvidia-smi --list-gpus``; its ids are used only when it reports more
       than one GPU.

    Args:
        value: Explicit id list; blank means "auto-detect".
        probe_timeout_seconds: Upper bound for the ``nvidia-smi`` probe so a
            wedged driver cannot hang the worker.

    Returns:
        The GPU ids, or ``[]`` when nothing usable was found (shards are then
        launched without ``--gpu-id``).

    Raises:
        ValueError: if *value* contains a token that is not an integer.
    """
    if value.strip():
        gpu_ids: list[int] = []
        for token in value.replace(",", " ").split():
            try:
                gpu_ids.append(int(token))
            except ValueError:
                raise ValueError(
                    f"Invalid GPU id {token!r} in {value!r}; "
                    "expected integers separated by commas or spaces"
                ) from None
        return gpu_ids

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
    if visible_devices and visible_devices != "-1":
        visible_ids = [
            int(token)
            for token in visible_devices.replace(",", " ").split()
            if token.isdigit()
        ]
        if visible_ids:
            return visible_ids

    try:
        result = subprocess.run(
            ["nvidia-smi", "--list-gpus"],
            capture_output=True,
            text=True,
            errors="replace",  # never fail on undecodable output
            check=False,
            timeout=probe_timeout_seconds,
        )
    except (OSError, subprocess.SubprocessError):
        # nvidia-smi missing, not executable, or timed out: no GPU pinning.
        return []
    if result.returncode != 0:
        return []
    count = sum(1 for line in result.stdout.splitlines() if line.startswith("GPU "))
    # A single detected GPU is deliberately not pinned (original behavior).
    return list(range(count)) if count > 1 else []
def _get_shard_gpu_id(shard_index: int, gpu_ids: list[int]) -> int | None:
    """Assign a GPU to *shard_index* by cycling round-robin through *gpu_ids*.

    Returns ``None`` when *gpu_ids* is empty, i.e. no GPU pinning is requested.
    """
    return gpu_ids[shard_index % len(gpu_ids)] if gpu_ids else None
def _run_command(
    name: str,
    cmd: list[str],
    *,
    log_path: str,
    cwd: str,
    manifest: dict[str, Any],
    timeout_seconds: int = 0,
) -> None:
    """Run *cmd* to completion as a named phase and record it in *manifest*.

    The phase entry is appended to ``manifest["phases"]`` and the manifest is
    saved when the phase starts and again when it finishes, whatever the
    outcome. *log_path* is overwritten with the command line followed by the
    child's combined stdout/stderr.

    Args:
        name: Phase name used in the manifest and in error messages.
        cmd: Argument vector, executed without a shell.
        log_path: File receiving the command line and the child's output.
        cwd: Working directory for the child process.
        manifest: Run manifest; must contain ``"phases"`` and ``"manifest_path"``.
        timeout_seconds: Wall-clock limit; ``0`` or negative disables it.

    Raises:
        RuntimeError: if the command exits non-zero or exceeds the timeout.
        OSError: if the log file cannot be opened or the command cannot be
            launched (the phase is marked as an error before re-raising).
    """
    started_at = time.time()
    phase: dict[str, Any] = {
        "name": name,
        "command": cmd,
        "log_path": log_path,
        "started_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
        # Matches the shard phase so a crashed run shows which phase was active.
        "status": "running",
    }
    manifest["phases"].append(phase)
    _save_manifest(manifest["manifest_path"], manifest)

    def finish(status: str, exit_code: int | None, error: str | None = None) -> None:
        """Stamp completion fields on the phase entry and persist the manifest."""
        phase["completed_at"] = time.strftime("%Y-%m-%dT%H:%M:%S")
        phase["duration_seconds"] = round(time.time() - started_at, 1)
        phase["exit_code"] = exit_code
        phase["status"] = status
        if error is not None:
            phase["error"] = error
        _save_manifest(manifest["manifest_path"], manifest)

    try:
        with open(log_path, "w", encoding="utf-8") as handle:
            handle.write("$ " + " ".join(cmd) + "\n\n")
            handle.flush()  # keep the header ahead of the child's output
            completed = subprocess.run(
                cmd,
                cwd=cwd,
                stdout=handle,
                stderr=subprocess.STDOUT,
                text=True,
                check=False,
                timeout=timeout_seconds if timeout_seconds > 0 else None,
            )
    except subprocess.TimeoutExpired as exc:
        # subprocess.run has already killed and reaped the child at this point.
        finish("error", None, f"Timed out after {timeout_seconds}s")
        raise RuntimeError(f"{name} timed out after {timeout_seconds}s; see {log_path}") from exc
    except OSError as exc:
        finish("error", None, f"Failed to launch: {exc}")
        raise

    finish("success" if completed.returncode == 0 else "error", completed.returncode)
    if completed.returncode != 0:
        raise RuntimeError(f"{name} failed with exit code {completed.returncode}; see {log_path}")
def _stop_shard_processes(
    processes: list[tuple[int, subprocess.Popen[str], Any]],
    already_done: set[int],
    *,
    grace_seconds: float = 2.0,
) -> dict[int, int]:
    """Stop and reap every shard whose index is not in *already_done*.

    Live processes get ``terminate()``; if any were signalled, wait
    *grace_seconds*, then ``kill()`` survivors. Every targeted process is
    ``wait()``-ed so none is left as a zombie.

    Returns:
        Exit code of each targeted shard keyed by shard index (on POSIX a
        negative value -N means the process was ended by signal N).
    """
    targets = [(index, proc) for index, proc, _ in processes if index not in already_done]
    signalled = False
    for _, proc in targets:
        if proc.poll() is None:
            proc.terminate()
            signalled = True
    if signalled:
        time.sleep(grace_seconds)
    exit_codes: dict[int, int] = {}
    for index, proc in targets:
        if proc.poll() is None:
            proc.kill()
        exit_codes[index] = proc.wait()
    return exit_codes


def _run_processing_shards(
    *,
    label: str,
    process_shards: int,
    process_limit: int,
    partition_total: int,
    partition_index: int,
    delete_after: bool,
    report_dir: str,
    manifest: dict[str, Any],
    shard_timeout_seconds: int = 0,
) -> None:
    """Run ``scripts/process-local-videos.py`` as *process_shards* parallel shards.

    Each shard gets ``--shard-count``/``--shard-index`` plus the partition
    arguments and writes ``worker-shard-<i>.json`` and ``worker-shard-<i>.log``
    into *report_dir*. GPU ids (from ``RUNPOD_WORKER_GPU_IDS`` or
    auto-detection) are assigned round-robin. Shards are polled every 5
    seconds; the first shard to exit non-zero causes all others to be stopped.
    If this function is interrupted or a launch fails, any started shards are
    stopped and their log files closed before the exception propagates.

    Raises:
        RuntimeError: if more shards than GPUs are requested on a multi-GPU
            host, if any shard fails, or if the phase exceeds
            *shard_timeout_seconds* (``0`` disables the limit).
    """
    started_at = time.time()
    phase: dict[str, Any] = {
        "name": "video-processing",
        "started_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "status": "running",
        "shards": [],
    }
    manifest["phases"].append(phase)
    _save_manifest(manifest["manifest_path"], manifest)

    def finish_phase(status: str, failed_shards: int) -> None:
        """Stamp completion fields on the phase entry and persist the manifest."""
        phase["completed_at"] = time.strftime("%Y-%m-%dT%H:%M:%S")
        phase["duration_seconds"] = round(time.time() - started_at, 1)
        phase["status"] = status
        phase["failed_shards"] = failed_shards
        _save_manifest(manifest["manifest_path"], manifest)

    def record_exit(shard_index: int, exit_code: int) -> None:
        """Store a shard's exit code and derived status in its manifest entry."""
        entry = phase["shards"][shard_index]
        entry["exit_code"] = exit_code
        entry["status"] = "success" if exit_code == 0 else "error"

    processes: list[tuple[int, subprocess.Popen[str], Any]] = []
    completed_shards: set[int] = set()
    shard_errors = 0
    timed_out = False
    try:
        gpu_ids = _parse_gpu_ids(os.environ.get("RUNPOD_WORKER_GPU_IDS", ""))
        if len(gpu_ids) > 1 and process_shards > len(gpu_ids):
            raise RuntimeError(
                f"process_shards={process_shards} exceeds available GPU ids ({gpu_ids}); "
                "use at most one shard per GPU for a multi-GPU proof"
            )

        for shard_index in range(process_shards):
            report_path = os.path.join(report_dir, f"worker-shard-{shard_index}.json")
            log_path = os.path.join(report_dir, f"worker-shard-{shard_index}.log")
            cmd = [
                sys.executable,
                os.path.join(REPO_ROOT, "scripts", "process-local-videos.py"),
                "--label",
                label,
                "--partition-total",
                str(partition_total),
                "--partition-index",
                str(partition_index),
                "--shard-count",
                str(process_shards),
                "--shard-index",
                str(shard_index),
                "--report",
                report_path,
            ]
            gpu_id = _get_shard_gpu_id(shard_index, gpu_ids)
            if gpu_id is not None:
                cmd.extend(["--gpu-id", str(gpu_id)])
            if process_limit > 0:
                cmd.extend(["--limit", str(process_limit)])
            if delete_after:
                cmd.append("--delete-after")
            # Entry position == shard_index; record_exit relies on this.
            phase["shards"].append(
                {
                    "shard_index": shard_index,
                    "command": cmd,
                    "report_path": report_path,
                    "log_path": log_path,
                    "status": "running",
                }
            )
            handle = open(log_path, "w", encoding="utf-8")
            try:
                handle.write("$ " + " ".join(cmd) + "\n\n")
                handle.flush()
                proc = subprocess.Popen(
                    cmd,
                    cwd=REPO_ROOT,
                    stdout=handle,
                    stderr=subprocess.STDOUT,
                    text=True,
                )
            except BaseException:
                handle.close()
                raise
            processes.append((shard_index, proc, handle))

        deadline = time.time() + shard_timeout_seconds if shard_timeout_seconds > 0 else None
        while len(completed_shards) < len(processes):
            if deadline is not None and time.time() > deadline:
                timed_out = True
                _stop_shard_processes(processes, completed_shards)
                break
            for shard_index, proc, _ in processes:
                if shard_index in completed_shards:
                    continue
                exit_code = proc.poll()
                if exit_code is None:
                    continue
                completed_shards.add(shard_index)
                record_exit(shard_index, exit_code)
                if exit_code != 0:
                    shard_errors += 1
                    # Fail fast: one failed shard stops all remaining shards.
                    stopped = _stop_shard_processes(processes, completed_shards)
                    for other_index, other_exit_code in stopped.items():
                        completed_shards.add(other_index)
                        record_exit(other_index, other_exit_code)
                        if other_exit_code != 0:
                            shard_errors += 1
                    break
            _save_manifest(manifest["manifest_path"], manifest)
            if len(completed_shards) < len(processes):
                time.sleep(5)
    except BaseException as exc:
        # Launch failure, invalid config, Ctrl-C, ...: never orphan GPU workers.
        stopped = _stop_shard_processes(processes, completed_shards)
        for entry in phase["shards"]:
            if entry["status"] == "running":
                entry["exit_code"] = stopped.get(entry["shard_index"])
                entry["status"] = "error"
        phase["error"] = f"{type(exc).__name__}: {exc}"
        finish_phase(
            "error",
            sum(1 for entry in phase["shards"] if entry["status"] != "success"),
        )
        raise
    finally:
        for _, _, handle in processes:
            handle.close()

    if timed_out:
        for shard_index, _, _ in processes:
            if shard_index not in completed_shards:
                entry = phase["shards"][shard_index]
                entry["exit_code"] = None
                entry["status"] = "error"
                entry["error"] = f"Timed out after {shard_timeout_seconds}s"
        finish_phase("error", len(processes) - len(completed_shards))
        raise RuntimeError(
            f"Video processing shards timed out after {shard_timeout_seconds}s; see {report_dir}"
        )

    finish_phase("success" if shard_errors == 0 else "error", shard_errors)
    if shard_errors:
        raise RuntimeError(f"Video processing shards failed ({shard_errors}/{process_shards}); see {report_dir}")
def main() -> int:
    """Run every enabled pipeline phase in order and maintain the run manifest.

    Phases, in order: catalog subtitles, video download, embedded-subtitle
    extraction, the optional subtitle-embedding and video-concept rebuilds,
    sharded video processing, and the reprocess-backlog report (always run).
    Each phase can be skipped or enabled through CLI flags.

    Returns:
        0 when every phase succeeds, 1 when any phase raises. Either way the
        outcome is written to ``<report-dir>/runpod-batch-manifest.json``.
        Invalid CLI combinations exit early via ``SystemExit``.
    """
    args = parse_args()
    if args.partition_total < 1:
        raise SystemExit("--partition-total must be >= 1")
    if args.partition_index < 0 or args.partition_index >= args.partition_total:
        raise SystemExit("--partition-index must be in [0, partition-total)")
    if not args.skip_video_processing and args.process_shards < 1:
        # Zero/negative shards would launch nothing yet report success.
        raise SystemExit("--process-shards must be >= 1")
    ensure_runtime_dirs()
    os.makedirs(args.report_dir, exist_ok=True)

    def report_file(name: str) -> str:
        """Path of *name* inside the report directory."""
        return os.path.join(args.report_dir, name)

    def script_file(name: str) -> str:
        """Path of a helper script in ``<repo>/scripts``."""
        return os.path.join(REPO_ROOT, "scripts", name)

    manifest_path = report_file("runpod-batch-manifest.json")
    manifest: dict[str, Any] = {
        "manifest_path": manifest_path,
        "started_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "cwd": REPO_ROOT,
        "language": args.language,
        "label": args.label,
        "env": {
            "SEARCH_UI_DATA_ROOT": os.environ.get("SEARCH_UI_DATA_ROOT"),
            "SEARCH_UI_DB_DIR": os.environ.get("SEARCH_UI_DB_DIR"),
            "SEARCH_UI_VIDEOS_DIR": os.environ.get("SEARCH_UI_VIDEOS_DIR"),
            "SEARCH_UI_SUBTITLES_DIR": os.environ.get("SEARCH_UI_SUBTITLES_DIR"),
            "SEARCH_UI_DEEPFACE_FORCE_CPU": os.environ.get("SEARCH_UI_DEEPFACE_FORCE_CPU"),
            "HF_HOME": os.environ.get("HF_HOME"),
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE"),
            "RUNPOD_PARTITION_TOTAL": str(args.partition_total),
            "RUNPOD_PARTITION_INDEX": str(args.partition_index),
        },
        "phases": [],
    }
    _save_manifest(manifest_path, manifest)

    def run_phase(name: str, cmd: list[str], log_name: str) -> None:
        """Run one non-sharded phase from the repo root under the command timeout."""
        _run_command(
            name,
            cmd,
            log_path=report_file(log_name),
            cwd=REPO_ROOT,
            manifest=manifest,
            timeout_seconds=args.command_timeout_seconds,
        )

    redownload_flags = ["--redownload-metadata"] if args.redownload_metadata else []
    try:
        if not args.skip_catalog_subtitles:
            run_phase(
                "catalog-subtitles",
                [
                    sys.executable,
                    script_file("download-and-index-subtitles.py"),
                    "--language",
                    args.language,
                    "--report",
                    report_file("catalog-subtitles.json"),
                    *redownload_flags,
                ],
                "catalog-subtitles.log",
            )
        if not args.skip_video_download:
            download_cmd = [
                sys.executable,
                script_file("download-all-videos.py"),
                "--language",
                args.language,
                "--label",
                args.label,
                "--concurrency",
                str(args.download_concurrency),
                "--partition-total",
                str(args.partition_total),
                "--partition-index",
                str(args.partition_index),
                "--report",
                report_file("video-download.json"),
                *redownload_flags,
            ]
            if args.process_limit > 0:
                download_cmd.extend(["--limit", str(args.process_limit)])
            run_phase("video-download", download_cmd, "video-download.log")
        if not args.skip_embedded_subtitles:
            run_phase(
                "embedded-subtitles",
                [
                    sys.executable,
                    script_file("backfill-embedded-subtitles.py"),
                    "--report",
                    report_file("embedded-subtitles.json"),
                ],
                "embedded-subtitles.log",
            )
        if args.rebuild_current_subtitle_embeddings:
            run_phase(
                "rebuild-subtitle-embeddings",
                [
                    sys.executable,
                    script_file("rebuild-subtitle-embeddings.py"),
                    "--language",
                    args.language,
                    "--batch-size",
                    "50",
                ],
                "rebuild-subtitle-embeddings.log",
            )
        if args.rebuild_video_concepts:
            run_phase(
                "rebuild-video-concepts",
                [
                    sys.executable,
                    script_file("rebuild-video-concepts.py"),
                    "--language",
                    args.language,
                    "--batch-size",
                    "50",
                ],
                "rebuild-video-concepts.log",
            )
        if not args.skip_video_processing:
            _run_processing_shards(
                label=args.label,
                process_shards=args.process_shards,
                process_limit=args.process_limit,
                partition_total=args.partition_total,
                partition_index=args.partition_index,
                delete_after=args.delete_video_after_processing,
                report_dir=args.report_dir,
                manifest=manifest,
                shard_timeout_seconds=args.shard_timeout_seconds,
            )
        run_phase(
            "reprocess-backlog-report",
            [
                sys.executable,
                script_file("report-reprocess-backlog.py"),
                "--output",
                report_file("reprocess-backlog.json"),
            ],
            "reprocess-backlog.log",
        )
    except Exception as exc:
        manifest["status"] = "error"
        manifest["error"] = str(exc)
        manifest["completed_at"] = time.strftime("%Y-%m-%dT%H:%M:%S")
        _save_manifest(manifest_path, manifest)
        print(str(exc))
        return 1
    manifest["status"] = "success"
    manifest["completed_at"] = time.strftime("%Y-%m-%dT%H:%M:%S")
    _save_manifest(manifest_path, manifest)
    print(json.dumps({"status": "success", "manifest_path": manifest_path}, indent=2))
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code as the exit status.
    sys.exit(main())