#!/usr/bin/env python3
"""Orchestrate the full headless Search-UI batch pipeline on a RunPod Pod.

The script runs a fixed sequence of helper scripts from ``scripts/`` as
subprocesses (catalog subtitles, video download, embedded subtitles, optional
rebuilds, sharded video processing, backlog report).  Every phase is recorded
in a JSON manifest inside the report directory, and each subprocess writes its
combined stdout/stderr to a per-phase log file next to it.
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import time
from typing import Any

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
BACKEND_DIR = os.path.join(REPO_ROOT, "backend")
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)

from runtime_paths import ensure_runtime_dirs
from utils import atomic_write_json

# Environment values (after strip + lower-casing) that switch a boolean flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})
# Timestamp layout used for every ``started_at`` / ``completed_at`` field.
_TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S"
# Seconds granted to a terminated child before it is killed.
_TERMINATE_GRACE_SECONDS = 2
# Seconds between two polls of the running shard processes.
_SHARD_POLL_SECONDS = 5


def _timestamp() -> str:
    """Return the current local time formatted for manifest fields."""
    return time.strftime(_TIMESTAMP_FORMAT)


def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy value."""
    return os.environ.get(name, "").strip().lower() in _TRUTHY_ENV_VALUES


def parse_args() -> argparse.Namespace:
    """Parse command-line options, taking defaults from ``RUNPOD_*`` variables.

    Integer options receive their environment default as a *string*; argparse
    applies ``type=int`` to string defaults only when the option is absent
    from the command line.  A malformed environment value therefore produces
    a regular argparse usage error instead of a traceback, and is ignored
    entirely when the option is given explicitly.
    """
    parser = argparse.ArgumentParser(description="Run the Search-UI batch worker on a RunPod Pod.")
    parser.add_argument("--language", default=os.environ.get("RUNPOD_WORKER_LANGUAGE", "E"))
    parser.add_argument("--label", default=os.environ.get("RUNPOD_WORKER_LABEL", "720p"))
    parser.add_argument(
        "--report-dir",
        default=os.environ.get(
            "RUNPOD_WORKER_REPORT_DIR",
            os.path.join(REPO_ROOT, "docs", "reports", "runpod"),
        ),
    )
    parser.add_argument(
        "--download-concurrency",
        type=int,
        default=os.environ.get("RUNPOD_DOWNLOAD_CONCURRENCY", "4"),
    )
    parser.add_argument(
        "--process-shards",
        type=int,
        default=os.environ.get("RUNPOD_PROCESS_SHARDS", "4"),
    )
    parser.add_argument(
        "--process-limit",
        type=int,
        default=os.environ.get("RUNPOD_PROCESS_LIMIT", "0"),
    )
    parser.add_argument(
        "--partition-total",
        type=int,
        default=os.environ.get("RUNPOD_PARTITION_TOTAL", "1"),
        help="Split the backlog into N deterministic partitions (default: 1).",
    )
    parser.add_argument(
        "--partition-index",
        type=int,
        default=os.environ.get("RUNPOD_PARTITION_INDEX", "0"),
        help="0-based partition index to process when partition-total > 1.",
    )
    parser.add_argument(
        "--command-timeout-seconds",
        type=int,
        default=os.environ.get("RUNPOD_COMMAND_TIMEOUT_SECONDS", "64800"),
        help="Maximum runtime for a single non-sharded phase (0 disables timeout).",
    )
    parser.add_argument(
        "--shard-timeout-seconds",
        type=int,
        default=os.environ.get("RUNPOD_SHARD_TIMEOUT_SECONDS", "64800"),
        help="Maximum runtime for the full shard-processing phase (0 disables timeout).",
    )
    parser.add_argument(
        "--skip-catalog-subtitles",
        action="store_true",
        help="Skip standalone subtitle download and indexing.",
    )
    parser.add_argument(
        "--skip-video-download",
        action="store_true",
        help="Skip the video download phase.",
    )
    parser.add_argument(
        "--skip-embedded-subtitles",
        action="store_true",
        default=_env_flag("RUNPOD_SKIP_EMBEDDED_SUBTITLES"),
        help="Skip embedded subtitle extraction for downloaded MP4s.",
    )
    parser.add_argument(
        "--skip-video-processing",
        action="store_true",
        help="Skip thumbnail/image/face processing.",
    )
    parser.add_argument(
        "--rebuild-current-subtitle-embeddings",
        action="store_true",
        help="Run the current-recipe subtitle embedding rebuild after extraction.",
    )
    parser.add_argument(
        "--rebuild-video-concepts",
        action="store_true",
        help="Run the current-recipe video-concept rebuild after extraction.",
    )
    parser.add_argument(
        "--delete-video-after-processing",
        action="store_true",
        default=_env_flag("RUNPOD_DELETE_VIDEO_AFTER_PROCESSING"),
        help="Delete MP4s after image/face processing succeeds.",
    )
    parser.add_argument(
        "--redownload-metadata",
        action="store_true",
        help="Force catalog metadata refresh before download phases.",
    )
    return parser.parse_args()


def _save_manifest(path: str, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* via ``atomic_write_json`` with 2-space indent."""
    atomic_write_json(path, payload, indent=2)


def _parse_gpu_ids(value: str) -> list[int]:
    """Resolve the GPU ids that processing shards should be pinned to.

    A non-blank *value* is an explicit comma- and/or whitespace-separated id
    list and is returned as parsed.  For a blank *value* the ids are taken
    from the numeric entries of ``CUDA_VISIBLE_DEVICES`` (unless it is unset,
    empty or ``-1``); failing that, ``nvidia-smi --list-gpus`` is consulted
    and ``range(count)`` is returned when it reports more than one GPU.

    Returns:
        The GPU ids, or an empty list when no pinning should be applied.

    Raises:
        ValueError: if an explicit *value* contains a non-integer entry.
    """
    if value.strip():
        try:
            return [int(part) for part in value.replace(",", " ").split()]
        except ValueError as exc:
            raise ValueError(f"Invalid GPU id list {value!r}: {exc}") from exc

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
    if visible_devices and visible_devices != "-1":
        # Non-numeric entries (for example GPU UUIDs) are skipped.
        parsed_visible = [
            int(part) for part in visible_devices.replace(",", " ").split() if part.isdigit()
        ]
        if parsed_visible:
            return parsed_visible

    try:
        result = subprocess.run(
            ["nvidia-smi", "--list-gpus"],
            capture_output=True,
            text=True,
            check=False,
        )
    except (OSError, subprocess.SubprocessError):
        # Best effort: nvidia-smi missing or not runnable means no pinning.
        return []
    if result.returncode == 0:
        count = sum(1 for line in result.stdout.splitlines() if line.startswith("GPU "))
        if count > 1:
            return list(range(count))
    return []


def _get_shard_gpu_id(shard_index: int, gpu_ids: list[int]) -> int | None:
    """Return the GPU id for *shard_index* (round-robin), or None without GPUs."""
    if not gpu_ids:
        return None
    return gpu_ids[shard_index % len(gpu_ids)]


def _run_command(
    name: str,
    cmd: list[str],
    *,
    log_path: str,
    cwd: str,
    manifest: dict[str, Any],
    timeout_seconds: int = 0,
) -> None:
    """Run one pipeline phase as a subprocess and record it in the manifest.

    The phase entry is appended to ``manifest["phases"]`` and saved before the
    command starts, then updated with completion time, duration, exit code and
    status once it ends.  The child's stdout and stderr go to *log_path*.

    Args:
        name: Phase name stored in the manifest and used in error messages.
        cmd: Argument vector to execute.
        log_path: File that receives the command line and the child's output.
        cwd: Working directory for the child process.
        manifest: Manifest dict; must contain ``phases`` and ``manifest_path``.
        timeout_seconds: Maximum runtime; values <= 0 disable the timeout.

    Raises:
        RuntimeError: if the command times out or exits with a non-zero code.
        OSError: if the log file cannot be opened or the process cannot start.
    """
    started_at = time.time()
    phase: dict[str, Any] = {
        "name": name,
        "command": cmd,
        "log_path": log_path,
        "started_at": _timestamp(),
        "status": "running",
    }
    manifest["phases"].append(phase)
    _save_manifest(manifest["manifest_path"], manifest)

    def record_outcome(status: str, exit_code: int | None, error: str | None = None) -> None:
        """Store the terminal state of the phase and persist the manifest."""
        phase["completed_at"] = _timestamp()
        phase["duration_seconds"] = round(time.time() - started_at, 1)
        phase["exit_code"] = exit_code
        phase["status"] = status
        if error is not None:
            phase["error"] = error
        _save_manifest(manifest["manifest_path"], manifest)

    try:
        with open(log_path, "w", encoding="utf-8") as handle:
            handle.write("$ " + " ".join(cmd) + "\n\n")
            # Flush so the header precedes the child's output in the file.
            handle.flush()
            completed = subprocess.run(
                cmd,
                cwd=cwd,
                stdout=handle,
                stderr=subprocess.STDOUT,
                text=True,
                check=False,
                timeout=timeout_seconds if timeout_seconds > 0 else None,
            )
    except subprocess.TimeoutExpired as exc:
        record_outcome("error", None, f"Timed out after {timeout_seconds}s")
        raise RuntimeError(f"{name} timed out after {timeout_seconds}s; see {log_path}") from exc
    except OSError as exc:
        # Do not leave the phase recorded as "running" when it never ran.
        record_outcome("error", None, str(exc))
        raise

    record_outcome("success" if completed.returncode == 0 else "error", completed.returncode)
    if completed.returncode != 0:
        raise RuntimeError(f"{name} failed with exit code {completed.returncode}; see {log_path}")


def _launch_shard(cmd: list[str], log_path: str) -> tuple[subprocess.Popen[str], Any]:
    """Start *cmd* in REPO_ROOT with output sent to *log_path*.

    Returns the process and the open log handle; the caller owns both.  The
    handle is closed here if the process cannot be started.
    """
    handle = open(log_path, "w", encoding="utf-8")
    try:
        handle.write("$ " + " ".join(cmd) + "\n\n")
        handle.flush()
        proc = subprocess.Popen(
            cmd,
            cwd=REPO_ROOT,
            stdout=handle,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except BaseException:
        handle.close()
        raise
    return proc, handle


def _stop_shard_processes(processes: list[tuple[int, subprocess.Popen[str], Any]]) -> None:
    """Terminate, then kill, then reap every process and close its log handle.

    Safe to call repeatedly: exited processes are only reaped again and
    closing an already-closed handle is a no-op.
    """
    signalled = False
    for _, proc, _ in processes:
        if proc.poll() is None:
            proc.terminate()
            signalled = True
    if signalled:
        # Give terminated children a moment to exit before killing them.
        time.sleep(_TERMINATE_GRACE_SECONDS)
    for _, proc, handle in processes:
        if proc.poll() is None:
            proc.kill()
        # Reap the child so it does not linger as a zombie.
        proc.wait()
        handle.close()


def _run_processing_shards(
    *,
    label: str,
    process_shards: int,
    process_limit: int,
    partition_total: int,
    partition_index: int,
    delete_after: bool,
    report_dir: str,
    manifest: dict[str, Any],
    shard_timeout_seconds: int = 0,
) -> None:
    """Run ``process-local-videos.py`` as parallel shards and track them.

    One subprocess is started per shard; each writes to
    ``worker-shard-<i>.log`` and is asked to report to
    ``worker-shard-<i>.json`` inside *report_dir*.  GPU ids come from
    ``RUNPOD_WORKER_GPU_IDS`` (see ``_parse_gpu_ids``) and are assigned
    round-robin.  The first shard that exits non-zero causes all remaining
    shards to be stopped.  The ``video-processing`` phase entry in the
    manifest is updated on every poll.

    Args:
        label: Value passed as ``--label``.
        process_shards: Number of shard subprocesses to start.
        process_limit: Passed as ``--limit`` when greater than zero.
        partition_total: Value passed as ``--partition-total``.
        partition_index: Value passed as ``--partition-index``.
        delete_after: Adds ``--delete-after`` when true.
        report_dir: Directory for shard reports and logs.
        manifest: Manifest dict; must contain ``phases`` and ``manifest_path``.
        shard_timeout_seconds: Deadline for the whole phase; <= 0 disables it.

    Raises:
        RuntimeError: if there are several GPU ids but more shards than ids,
            if the deadline passes, or if any shard exits non-zero.
        ValueError: if ``RUNPOD_WORKER_GPU_IDS`` is malformed.
        OSError: if a shard log cannot be opened or a shard cannot be started.
    """
    started_at = time.time()
    phase: dict[str, Any] = {
        "name": "video-processing",
        "started_at": _timestamp(),
        "status": "running",
        "shards": [],
    }
    manifest["phases"].append(phase)
    _save_manifest(manifest["manifest_path"], manifest)

    def finish_phase(status: str, failed_shards: int, error: str | None = None) -> None:
        """Store the terminal state of the phase and persist the manifest."""
        phase["completed_at"] = _timestamp()
        phase["duration_seconds"] = round(time.time() - started_at, 1)
        phase["status"] = status
        phase["failed_shards"] = failed_shards
        if error is not None:
            phase["error"] = error
        _save_manifest(manifest["manifest_path"], manifest)

    processes: list[tuple[int, subprocess.Popen[str], Any]] = []
    shard_errors = 0
    try:
        gpu_ids = _parse_gpu_ids(os.environ.get("RUNPOD_WORKER_GPU_IDS", ""))
        if len(gpu_ids) > 1 and process_shards > len(gpu_ids):
            raise RuntimeError(
                f"process_shards={process_shards} exceeds available GPU ids ({gpu_ids}); "
                "use at most one shard per GPU for a multi-GPU proof"
            )

        for shard_index in range(process_shards):
            report_path = os.path.join(report_dir, f"worker-shard-{shard_index}.json")
            log_path = os.path.join(report_dir, f"worker-shard-{shard_index}.log")
            cmd = [
                sys.executable,
                os.path.join(REPO_ROOT, "scripts", "process-local-videos.py"),
                "--label",
                label,
                "--partition-total",
                str(partition_total),
                "--partition-index",
                str(partition_index),
                "--shard-count",
                str(process_shards),
                "--shard-index",
                str(shard_index),
                "--report",
                report_path,
            ]
            gpu_id = _get_shard_gpu_id(shard_index, gpu_ids)
            if gpu_id is not None:
                cmd.extend(["--gpu-id", str(gpu_id)])
            if process_limit > 0:
                cmd.extend(["--limit", str(process_limit)])
            if delete_after:
                cmd.append("--delete-after")
            phase["shards"].append(
                {
                    "shard_index": shard_index,
                    "command": cmd,
                    "report_path": report_path,
                    "log_path": log_path,
                    "status": "running",
                }
            )
            proc, handle = _launch_shard(cmd, log_path)
            processes.append((shard_index, proc, handle))

        deadline = time.time() + shard_timeout_seconds if shard_timeout_seconds > 0 else None
        completed_shards: set[int] = set()
        while len(completed_shards) < len(processes):
            if deadline is not None and time.time() > deadline:
                _stop_shard_processes(processes)
                for shard_index, _, _ in processes:
                    if shard_index not in completed_shards:
                        shard = phase["shards"][shard_index]
                        shard["exit_code"] = None
                        shard["status"] = "error"
                        shard["error"] = f"Timed out after {shard_timeout_seconds}s"
                finish_phase("error", len(processes) - len(completed_shards))
                raise RuntimeError(
                    f"Video processing shards timed out after {shard_timeout_seconds}s; see {report_dir}"
                )

            for shard_index, proc, handle in processes:
                if shard_index in completed_shards:
                    continue
                exit_code = proc.poll()
                if exit_code is None:
                    continue
                handle.close()
                completed_shards.add(shard_index)
                phase["shards"][shard_index]["exit_code"] = exit_code
                phase["shards"][shard_index]["status"] = "success" if exit_code == 0 else "error"
                if exit_code == 0:
                    continue

                # Fail fast: one failed shard stops every shard still pending.
                shard_errors += 1
                remaining = [entry for entry in processes if entry[0] not in completed_shards]
                _stop_shard_processes(remaining)
                for other_index, other_proc, _ in remaining:
                    other_exit_code = other_proc.returncode
                    completed_shards.add(other_index)
                    phase["shards"][other_index]["exit_code"] = other_exit_code
                    phase["shards"][other_index]["status"] = (
                        "success" if other_exit_code == 0 else "error"
                    )
                    if other_exit_code != 0:
                        shard_errors += 1
                _save_manifest(manifest["manifest_path"], manifest)
                break

            _save_manifest(manifest["manifest_path"], manifest)
            if len(completed_shards) < len(processes):
                time.sleep(_SHARD_POLL_SECONDS)

        finish_phase("success" if shard_errors == 0 else "error", shard_errors)
    except BaseException as exc:
        # Never leave shard processes running (or the phase marked "running")
        # when setup, launching or polling aborts unexpectedly.
        _stop_shard_processes(processes)
        if phase["status"] == "running":
            message = str(exc) or type(exc).__name__
            launched = {index: proc for index, proc, _ in processes}
            for shard in phase["shards"]:
                if shard["status"] == "running":
                    proc = launched.get(shard["shard_index"])
                    shard["exit_code"] = proc.returncode if proc is not None else None
                    shard["status"] = "error"
                    shard["error"] = message
            failed = sum(1 for shard in phase["shards"] if shard["status"] != "success")
            finish_phase("error", failed, message)
        raise

    if shard_errors:
        raise RuntimeError(
            f"Video processing shards failed ({shard_errors}/{process_shards}); see {report_dir}"
        )


def main() -> int:
    """Run the configured pipeline phases in order.

    Returns:
        0 when every executed phase succeeded, 1 when a phase raised; the
        manifest records ``status`` (and ``error``) in both cases.

    Raises:
        SystemExit: if the partition arguments are out of range.
    """
    args = parse_args()
    if args.partition_total < 1:
        raise SystemExit("--partition-total must be >= 1")
    if args.partition_index < 0 or args.partition_index >= args.partition_total:
        raise SystemExit("--partition-index must be in [0, partition-total)")

    ensure_runtime_dirs()
    os.makedirs(args.report_dir, exist_ok=True)
    manifest_path = os.path.join(args.report_dir, "runpod-batch-manifest.json")
    manifest: dict[str, Any] = {
        "manifest_path": manifest_path,
        "started_at": _timestamp(),
        "cwd": REPO_ROOT,
        "language": args.language,
        "label": args.label,
        "env": {
            "SEARCH_UI_DATA_ROOT": os.environ.get("SEARCH_UI_DATA_ROOT"),
            "SEARCH_UI_DB_DIR": os.environ.get("SEARCH_UI_DB_DIR"),
            "SEARCH_UI_VIDEOS_DIR": os.environ.get("SEARCH_UI_VIDEOS_DIR"),
            "SEARCH_UI_SUBTITLES_DIR": os.environ.get("SEARCH_UI_SUBTITLES_DIR"),
            "SEARCH_UI_DEEPFACE_FORCE_CPU": os.environ.get("SEARCH_UI_DEEPFACE_FORCE_CPU"),
            "HF_HOME": os.environ.get("HF_HOME"),
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE"),
            "RUNPOD_PARTITION_TOTAL": str(args.partition_total),
            "RUNPOD_PARTITION_INDEX": str(args.partition_index),
        },
        "phases": [],
    }
    _save_manifest(manifest_path, manifest)

    try:
        if not args.skip_catalog_subtitles:
            _run_command(
                "catalog-subtitles",
                [
                    sys.executable,
                    os.path.join(REPO_ROOT, "scripts", "download-and-index-subtitles.py"),
                    "--language",
                    args.language,
                    "--report",
                    os.path.join(args.report_dir, "catalog-subtitles.json"),
                    *(["--redownload-metadata"] if args.redownload_metadata else []),
                ],
                log_path=os.path.join(args.report_dir, "catalog-subtitles.log"),
                cwd=REPO_ROOT,
                manifest=manifest,
                timeout_seconds=args.command_timeout_seconds,
            )

        if not args.skip_video_download:
            download_cmd = [
                sys.executable,
                os.path.join(REPO_ROOT, "scripts", "download-all-videos.py"),
                "--language",
                args.language,
                "--label",
                args.label,
                "--concurrency",
                str(args.download_concurrency),
                "--partition-total",
                str(args.partition_total),
                "--partition-index",
                str(args.partition_index),
                "--report",
                os.path.join(args.report_dir, "video-download.json"),
                *(["--redownload-metadata"] if args.redownload_metadata else []),
            ]
            if args.process_limit > 0:
                download_cmd.extend(["--limit", str(args.process_limit)])
            _run_command(
                "video-download",
                download_cmd,
                log_path=os.path.join(args.report_dir, "video-download.log"),
                cwd=REPO_ROOT,
                manifest=manifest,
                timeout_seconds=args.command_timeout_seconds,
            )

        if not args.skip_embedded_subtitles:
            _run_command(
                "embedded-subtitles",
                [
                    sys.executable,
                    os.path.join(REPO_ROOT, "scripts", "backfill-embedded-subtitles.py"),
                    "--report",
                    os.path.join(args.report_dir, "embedded-subtitles.json"),
                ],
                log_path=os.path.join(args.report_dir, "embedded-subtitles.log"),
                cwd=REPO_ROOT,
                manifest=manifest,
                timeout_seconds=args.command_timeout_seconds,
            )

        if args.rebuild_current_subtitle_embeddings:
            _run_command(
                "rebuild-subtitle-embeddings",
                [
                    sys.executable,
                    os.path.join(REPO_ROOT, "scripts", "rebuild-subtitle-embeddings.py"),
                    "--language",
                    args.language,
                    "--batch-size",
                    "50",
                ],
                log_path=os.path.join(args.report_dir, "rebuild-subtitle-embeddings.log"),
                cwd=REPO_ROOT,
                manifest=manifest,
                timeout_seconds=args.command_timeout_seconds,
            )

        if args.rebuild_video_concepts:
            _run_command(
                "rebuild-video-concepts",
                [
                    sys.executable,
                    os.path.join(REPO_ROOT, "scripts", "rebuild-video-concepts.py"),
                    "--language",
                    args.language,
                    "--batch-size",
                    "50",
                ],
                log_path=os.path.join(args.report_dir, "rebuild-video-concepts.log"),
                cwd=REPO_ROOT,
                manifest=manifest,
                timeout_seconds=args.command_timeout_seconds,
            )

        if not args.skip_video_processing:
            _run_processing_shards(
                label=args.label,
                process_shards=args.process_shards,
                process_limit=args.process_limit,
                partition_total=args.partition_total,
                partition_index=args.partition_index,
                delete_after=args.delete_video_after_processing,
                report_dir=args.report_dir,
                manifest=manifest,
                shard_timeout_seconds=args.shard_timeout_seconds,
            )

        _run_command(
            "reprocess-backlog-report",
            [
                sys.executable,
                os.path.join(REPO_ROOT, "scripts", "report-reprocess-backlog.py"),
                "--output",
                os.path.join(args.report_dir, "reprocess-backlog.json"),
            ],
            log_path=os.path.join(args.report_dir, "reprocess-backlog.log"),
            cwd=REPO_ROOT,
            manifest=manifest,
            timeout_seconds=args.command_timeout_seconds,
        )
    except Exception as exc:
        # Top-level boundary: record the failure and signal it via exit code.
        manifest["status"] = "error"
        manifest["error"] = str(exc)
        manifest["completed_at"] = _timestamp()
        _save_manifest(manifest_path, manifest)
        print(str(exc))
        return 1

    manifest["status"] = "success"
    manifest["completed_at"] = _timestamp()
    _save_manifest(manifest_path, manifest)
    print(json.dumps({"status": "success", "manifest_path": manifest_path}, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())