| |
| """Prepare JW Search data for a Hugging Face Space at container startup.""" |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import json |
| import os |
| from pathlib import Path |
| import shutil |
| import tarfile |
| from urllib.error import URLError |
| from urllib.request import urlopen, urlretrieve |
| import zipfile |
|
|
|
|
| LOGGER = logging.getLogger("prepare-runtime-data") |
| DEFAULT_DATA_ROOT = "/data/search-ui" |
| URL_TIMEOUT_SECONDS = 300 |
| READY_MARKERS = ("database.db", "json") |
| READY_MARKER_FILE = ".search-ui-data-ready" |
| DATA_DIRECTORIES = ( |
| "json", |
| "subtitles", |
| "videos", |
| "publications", |
| "transcriptions", |
| ) |
|
|
|
|
| def configure_logging() -> None: |
| """Configure startup logging.""" |
| logging.basicConfig(level=logging.INFO, format="%(message)s") |
|
|
|
|
| def env_truthy(name: str) -> bool: |
| """Return whether an environment variable is truthy.""" |
| return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"} |
|
|
|
|
| def env_falsey(name: str) -> bool: |
| """Return whether an environment variable is explicitly false.""" |
| return os.getenv(name, "").strip().lower() in {"0", "false", "no", "off"} |
|
|
|
|
| def has_runtime_data(data_root: Path) -> bool: |
| """Return whether the expected data bundle markers already exist.""" |
| return any((data_root / marker).exists() for marker in READY_MARKERS) |
|
|
|
|
| def has_completed_runtime_data(data_root: Path) -> bool: |
| """Return whether a bundle extraction completed successfully.""" |
| return (data_root / READY_MARKER_FILE).exists() |
|
|
|
|
| def mark_runtime_data_ready(data_root: Path) -> None: |
| """Write the marker used to distinguish complete data from partial data.""" |
| (data_root / READY_MARKER_FILE).write_text("ready\n", encoding="utf-8") |
|
|
|
|
| def clear_runtime_data(data_root: Path) -> None: |
| """Clear runtime data while preserving download caches.""" |
| for child in data_root.iterdir(): |
| if child.name == ".bundle-cache": |
| continue |
| if child.is_dir() and not child.is_symlink(): |
| shutil.rmtree(child) |
| else: |
| child.unlink() |
|
|
|
|
| def iter_database_files(source_dir: Path) -> list[Path]: |
| """Return root-level SQLite database files from a data source.""" |
| return sorted(source_dir.glob("*.db")) |
|
|
|
|
| def link_or_copy(source: Path, target: Path, *, symlink: bool) -> None: |
| """Place one source path at target by symlink or copy.""" |
| if target.exists() or target.is_symlink(): |
| return |
| target.parent.mkdir(parents=True, exist_ok=True) |
| if symlink: |
| target.symlink_to(source) |
| elif source.is_dir(): |
| shutil.copytree(source, target) |
| else: |
| shutil.copy2(source, target) |
|
|
|
|
| def prepare_from_source_dir(source_dir: Path, data_root: Path) -> None: |
| """Prepare runtime data from an already-unpacked mounted directory.""" |
| if not source_dir.exists(): |
| raise RuntimeError(f"Configured data source does not exist: {source_dir}") |
| if not source_dir.is_dir(): |
| raise RuntimeError(f"Configured data source is not a directory: {source_dir}") |
|
|
| LOGGER.info("Preparing runtime data from mounted source: %s", source_dir) |
| for database_path in iter_database_files(source_dir): |
| target_path = data_root / database_path.name |
| if not target_path.exists(): |
| LOGGER.info("Copying database %s", database_path.name) |
| shutil.copy2(database_path, target_path) |
|
|
| for directory_name in DATA_DIRECTORIES: |
| source_path = source_dir / directory_name |
| if not source_path.exists(): |
| continue |
| target_path = data_root / directory_name |
| if target_path.exists() or target_path.is_symlink(): |
| continue |
| LOGGER.info("Linking data directory %s", directory_name) |
| target_path.symlink_to(source_path) |
|
|
|
|
| def download_url(url: str, cache_dir: Path) -> Path: |
| """Download a data bundle URL into the runtime cache.""" |
| cache_dir.mkdir(parents=True, exist_ok=True) |
| filename = Path(url.split("?", 1)[0]).name or "search-ui-data.tar" |
| target = cache_dir / filename |
| temp_target = target.with_suffix(target.suffix + ".tmp") |
|
|
| if target.exists() and target.stat().st_size > 0: |
| LOGGER.info("Using cached data bundle: %s", target) |
| return target |
|
|
| LOGGER.info("Downloading data bundle from %s", url) |
| try: |
| urlretrieve(url, temp_target) |
| except (OSError, URLError) as exc: |
| raise RuntimeError(f"Unable to download data bundle from {url}: {exc}") from exc |
|
|
| temp_target.replace(target) |
| return target |
|
|
|
|
| def download_hub_file(cache_dir: Path) -> Path | None: |
| """Download a data bundle from a Hugging Face repo when configured.""" |
| repo_id = os.getenv("SEARCH_UI_DATA_BUNDLE_REPO_ID", "").strip() |
| filename = os.getenv("SEARCH_UI_DATA_BUNDLE_FILENAME", "").strip() |
| if not repo_id and not filename: |
| return None |
| if not repo_id or not filename: |
| raise RuntimeError( |
| "Set both SEARCH_UI_DATA_BUNDLE_REPO_ID and " |
| "SEARCH_UI_DATA_BUNDLE_FILENAME, or neither." |
| ) |
|
|
| from huggingface_hub import hf_hub_download |
|
|
| repo_type = os.getenv("SEARCH_UI_DATA_BUNDLE_REPO_TYPE", "dataset").strip() or "dataset" |
| revision = os.getenv("SEARCH_UI_DATA_BUNDLE_REVISION", "").strip() or None |
| LOGGER.info("Downloading data bundle from Hugging Face %s repo %s", repo_type, repo_id) |
| return Path( |
| hf_hub_download( |
| repo_id=repo_id, |
| filename=filename, |
| repo_type=repo_type, |
| revision=revision, |
| local_dir=str(cache_dir / "hub"), |
| ) |
| ) |
|
|
|
|
| def resolve_bundle(cache_dir: Path) -> Path | None: |
| """Resolve the configured data bundle source.""" |
| local_path = os.getenv("SEARCH_UI_DATA_BUNDLE_PATH", "").strip() |
| if local_path: |
| bundle_path = Path(local_path).expanduser().resolve() |
| if not bundle_path.exists(): |
| raise RuntimeError(f"Configured data bundle does not exist: {bundle_path}") |
| return bundle_path |
|
|
| hub_path = download_hub_file(cache_dir) |
| if hub_path is not None: |
| return hub_path |
|
|
| url = os.getenv("SEARCH_UI_DATA_BUNDLE_URL", "").strip() |
| if url: |
| return download_url(url, cache_dir) |
|
|
| return None |
|
|
|
|
| def validate_member_path(target_dir: Path, member_name: str) -> None: |
| """Reject archive members that would write outside the data root.""" |
| target_dir = target_dir.resolve() |
| destination = (target_dir / member_name).resolve() |
| if destination == target_dir: |
| return |
| if not str(destination).startswith(str(target_dir) + os.sep): |
| raise RuntimeError(f"Unsafe data bundle path rejected: {member_name}") |
|
|
|
|
| def is_ignorable_archive_member(member_name: str) -> bool: |
| """Return whether a bundle member is safe to skip.""" |
| parts = [part for part in member_name.split("/") if part and part != "."] |
| return any(part == "__MACOSX" or part.startswith("._") for part in parts) |
|
|
|
|
| def extract_tar_member(archive: tarfile.TarFile, member: tarfile.TarInfo, target_dir: Path) -> None: |
| """Safely extract one tar member into the data root.""" |
| if is_ignorable_archive_member(member.name): |
| return |
|
|
| validate_member_path(target_dir, member.name) |
| if member.issym() or member.islnk(): |
| raise RuntimeError(f"Unsafe tar link rejected: {member.name}") |
|
|
| destination = (target_dir / member.name).resolve() |
| if member.isdir(): |
| destination.mkdir(parents=True, exist_ok=True) |
| return |
| if not member.isfile(): |
| raise RuntimeError(f"Unsupported tar member rejected: {member.name}") |
|
|
| source = archive.extractfile(member) |
| if source is None: |
| raise RuntimeError(f"Unable to read tar member: {member.name}") |
|
|
| destination.parent.mkdir(parents=True, exist_ok=True) |
| with source, destination.open("wb") as target: |
| shutil.copyfileobj(source, target, length=1024 * 1024) |
|
|
|
|
| def extract_tar(archive_path: Path, target_dir: Path) -> None: |
| """Safely extract a tar archive into the data root.""" |
| with tarfile.open(archive_path) as archive: |
| for member in archive: |
| extract_tar_member(archive, member, target_dir) |
|
|
|
|
| def extract_tar_stream(fileobj, target_dir: Path) -> None: |
| """Safely stream-extract a tar archive into the data root.""" |
| with tarfile.open(fileobj=fileobj, mode="r|*") as archive: |
| for member in archive: |
| extract_tar_member(archive, member, target_dir) |
|
|
|
|
| def read_json_url(url: str) -> dict: |
| """Read a JSON document from a URL.""" |
| try: |
| with urlopen(url, timeout=URL_TIMEOUT_SECONDS) as response: |
| return json.loads(response.read().decode("utf-8")) |
| except (OSError, URLError, json.JSONDecodeError) as exc: |
| raise RuntimeError(f"Unable to read data bundle manifest from {url}: {exc}") from exc |
|
|
|
|
| def resolve_part_urls(manifest: dict) -> list[str]: |
| """Resolve data bundle part URLs from a manifest document.""" |
| raw_parts = manifest.get("parts") |
| if not isinstance(raw_parts, list) or not raw_parts: |
| raise RuntimeError("Data bundle parts manifest must include a non-empty parts list.") |
|
|
| urls: list[str] = [] |
| for index, item in enumerate(raw_parts, start=1): |
| if isinstance(item, str): |
| url = item |
| elif isinstance(item, dict) and isinstance(item.get("url"), str): |
| url = item["url"] |
| else: |
| raise RuntimeError(f"Invalid data bundle part at index {index}.") |
|
|
| if not url.startswith(("http://", "https://")): |
| raise RuntimeError(f"Data bundle part URL must be HTTP(S): {url}") |
| urls.append(url) |
| return urls |
|
|
|
|
| def download_part(url: str, target_path: Path, *, index: int, total: int) -> Path: |
| """Download one bundle part to a temporary local file before reading it.""" |
| target_path.parent.mkdir(parents=True, exist_ok=True) |
| temp_path = target_path.with_suffix(target_path.suffix + ".tmp") |
| last_error: Exception | None = None |
|
|
| for attempt in range(1, 4): |
| if temp_path.exists(): |
| temp_path.unlink() |
| try: |
| LOGGER.info("Downloading data bundle part %s/%s", index, total) |
| with urlopen(url, timeout=URL_TIMEOUT_SECONDS) as response, temp_path.open("wb") as output: |
| shutil.copyfileobj(response, output, length=1024 * 1024) |
|
|
| expected_size = response.headers.get("Content-Length") |
| if expected_size and temp_path.stat().st_size != int(expected_size): |
| raise RuntimeError( |
| f"part size mismatch: expected {expected_size}, got {temp_path.stat().st_size}" |
| ) |
|
|
| temp_path.replace(target_path) |
| return target_path |
| except (OSError, URLError, RuntimeError) as exc: |
| last_error = exc |
| LOGGER.warning( |
| "Failed to download data bundle part %s/%s on attempt %s: %s", |
| index, |
| total, |
| attempt, |
| exc, |
| ) |
|
|
| raise RuntimeError(f"Unable to download data bundle part {index}/{total}: {last_error}") |
|
|
|
|
| class SequentialPartReader: |
| """File-like reader that concatenates downloaded bundle part files.""" |
|
|
| def __init__(self, part_urls: list[str], cache_dir: Path): |
| self.part_urls = part_urls |
| self.cache_dir = cache_dir |
| self.total = len(part_urls) |
| self.index = 0 |
| self.current_file = None |
| self.current_path: Path | None = None |
|
|
| def close(self) -> None: |
| if self.current_file is not None: |
| self.current_file.close() |
| self.current_file = None |
| if self.current_path is not None and self.current_path.exists(): |
| self.current_path.unlink() |
| self.current_path = None |
|
|
| def _open_next_part(self) -> bool: |
| self.close() |
| if self.index >= self.total: |
| return False |
|
|
| self.index += 1 |
| target_path = self.cache_dir / f"bundle-part-{self.index:04d}" |
| self.current_path = download_part( |
| self.part_urls[self.index - 1], |
| target_path, |
| index=self.index, |
| total=self.total, |
| ) |
| self.current_file = self.current_path.open("rb") |
| return True |
|
|
| def read(self, size: int = -1) -> bytes: |
| if size is None or size < 0: |
| size = 1024 * 1024 |
|
|
| chunks: list[bytes] = [] |
| remaining = size |
| while remaining > 0: |
| if self.current_file is None and not self._open_next_part(): |
| break |
|
|
| chunk = self.current_file.read(remaining) |
| if chunk: |
| chunks.append(chunk) |
| remaining -= len(chunk) |
| continue |
|
|
| self.close() |
|
|
| return b"".join(chunks) |
|
|
|
|
| def stream_extract_url(url: str, target_dir: Path) -> None: |
| """Stream a remote tar bundle directly into the data root.""" |
| LOGGER.info("Streaming data bundle from %s", url) |
| try: |
| with urlopen(url, timeout=URL_TIMEOUT_SECONDS) as response: |
| extract_tar_stream(response, target_dir) |
| except (OSError, URLError, tarfile.TarError) as exc: |
| raise RuntimeError(f"Unable to stream data bundle from {url}: {exc}") from exc |
|
|
|
|
| def stream_extract_parts_manifest(manifest_url: str, target_dir: Path, cache_dir: Path) -> None: |
| """Stream-extract a split tar bundle described by a manifest URL.""" |
| LOGGER.info("Streaming split data bundle from manifest %s", manifest_url) |
| manifest = read_json_url(manifest_url) |
| part_urls = resolve_part_urls(manifest) |
| reader = SequentialPartReader(part_urls, cache_dir / "parts") |
| try: |
| extract_tar_stream(reader, target_dir) |
| except (OSError, URLError, tarfile.TarError) as exc: |
| raise RuntimeError(f"Unable to stream split data bundle: {exc}") from exc |
| finally: |
| reader.close() |
|
|
|
|
| def extract_zip(archive_path: Path, target_dir: Path) -> None: |
| """Safely extract a zip archive into the data root.""" |
| with zipfile.ZipFile(archive_path) as archive: |
| for member_name in archive.namelist(): |
| validate_member_path(target_dir, member_name) |
| archive.extractall(target_dir) |
|
|
|
|
| def extract_bundle(archive_path: Path, target_dir: Path) -> None: |
| """Extract a supported data bundle archive.""" |
| LOGGER.info("Extracting data bundle %s into %s", archive_path, target_dir) |
| if zipfile.is_zipfile(archive_path): |
| extract_zip(archive_path, target_dir) |
| return |
| if tarfile.is_tarfile(archive_path): |
| extract_tar(archive_path, target_dir) |
| return |
| raise RuntimeError(f"Unsupported data bundle format: {archive_path}") |
|
|
|
|
| def main() -> int: |
| """Prepare runtime data and return a process exit code.""" |
| configure_logging() |
| data_root = Path(os.getenv("SEARCH_UI_DATA_ROOT", DEFAULT_DATA_ROOT)).resolve() |
| cache_dir = data_root / ".bundle-cache" |
| force = env_truthy("SEARCH_UI_FORCE_DATA_BUNDLE") |
|
|
| data_root.mkdir(parents=True, exist_ok=True) |
| source_dir = os.getenv("SEARCH_UI_DATA_SOURCE_DIR", "").strip() |
| if source_dir: |
| if has_runtime_data(data_root) and not force: |
| LOGGER.info("Using existing runtime data in %s", data_root) |
| else: |
| if force: |
| LOGGER.info("Clearing existing runtime data before source preparation.") |
| for child in data_root.iterdir(): |
| if child.name == ".bundle-cache": |
| continue |
| if child.is_dir() and not child.is_symlink(): |
| shutil.rmtree(child) |
| else: |
| child.unlink() |
| prepare_from_source_dir(Path(source_dir).expanduser().resolve(), data_root) |
| return 0 |
|
|
| parts_manifest_url = os.getenv("SEARCH_UI_DATA_BUNDLE_PARTS_MANIFEST_URL", "").strip() |
| if parts_manifest_url: |
| if has_completed_runtime_data(data_root) and not force: |
| LOGGER.info("Using existing runtime data in %s", data_root) |
| return 0 |
| if force or has_runtime_data(data_root): |
| LOGGER.info("Clearing existing runtime data before split bundle extraction.") |
| clear_runtime_data(data_root) |
| stream_extract_parts_manifest(parts_manifest_url, data_root, cache_dir) |
| mark_runtime_data_ready(data_root) |
| LOGGER.info("Runtime data is ready in %s", data_root) |
| return 0 |
|
|
| url = os.getenv("SEARCH_UI_DATA_BUNDLE_URL", "").strip() |
| if url and not env_falsey("SEARCH_UI_STREAM_DATA_BUNDLE"): |
| if has_completed_runtime_data(data_root) and not force: |
| LOGGER.info("Using existing runtime data in %s", data_root) |
| return 0 |
| if force or has_runtime_data(data_root): |
| LOGGER.info("Clearing existing runtime data before stream extraction.") |
| clear_runtime_data(data_root) |
| stream_extract_url(url, data_root) |
| mark_runtime_data_ready(data_root) |
| LOGGER.info("Runtime data is ready in %s", data_root) |
| return 0 |
|
|
| bundle_path = resolve_bundle(cache_dir) |
| if bundle_path is None: |
| if has_runtime_data(data_root): |
| LOGGER.info("Using existing runtime data in %s", data_root) |
| else: |
| LOGGER.warning( |
| "No data bundle configured. The app will start, but searches will be empty." |
| ) |
| return 0 |
|
|
| if has_runtime_data(data_root) and not force: |
| LOGGER.info( |
| "Runtime data already exists in %s. Set SEARCH_UI_FORCE_DATA_BUNDLE=1 " |
| "to re-extract the configured bundle.", |
| data_root, |
| ) |
| return 0 |
|
|
| if force: |
| LOGGER.info("Clearing existing runtime data before bundle extraction.") |
| clear_runtime_data(data_root) |
|
|
| extract_bundle(bundle_path, data_root) |
| mark_runtime_data_ready(data_root) |
| LOGGER.info("Runtime data is ready in %s", data_root) |
| return 0 |
|
|
|
|
| if __name__ == "__main__": |
| raise SystemExit(main()) |
|
|