jw-search / deploy /huggingface /prepare_runtime_data.py
G Davies
Skip macOS metadata silently during data prep
4caa7e2 verified
#!/usr/bin/env python3
"""Prepare JW Search data for a Hugging Face Space at container startup."""
from __future__ import annotations
import logging
import json
import os
from pathlib import Path
import shutil
import tarfile
from urllib.error import URLError
from urllib.request import urlopen, urlretrieve
import zipfile
LOGGER = logging.getLogger("prepare-runtime-data")

# Where prepared data lives unless SEARCH_UI_DATA_ROOT overrides it.
DEFAULT_DATA_ROOT = "/data/search-ui"
# Timeout (seconds) passed to urlopen calls.
URL_TIMEOUT_SECONDS = 300
# Entries whose presence in the data root means some runtime data already exists.
READY_MARKERS = ("database.db", "json")
# Written only after preparation finishes, to tell complete data from partial data.
READY_MARKER_FILE = ".search-ui-data-ready"
# Data subdirectories linked in from a mounted source directory.
DATA_DIRECTORIES = ("json", "subtitles", "videos", "publications", "transcriptions")
def configure_logging() -> None:
    """Set up INFO-level logging that prints bare messages."""
    logging.basicConfig(format="%(message)s", level=logging.INFO)
def env_truthy(name: str) -> bool:
    """Report whether environment variable *name* holds an "on" value.

    Matching ignores case and surrounding whitespace; unset counts as not truthy.
    """
    raw = os.environ.get(name) or ""
    normalized = raw.strip().lower()
    return normalized in ("1", "true", "yes", "on")
def env_falsey(name: str) -> bool:
    """Report whether environment variable *name* is explicitly switched off.

    Unlike ``not env_truthy(name)``, an unset or empty variable is NOT falsey here.
    """
    raw = os.environ.get(name) or ""
    normalized = raw.strip().lower()
    return normalized in ("0", "false", "no", "off")
def has_runtime_data(data_root: Path) -> bool:
    """Report whether any READY_MARKERS entry exists under *data_root*.

    This only says that some data is present; it does not prove it is complete.
    """
    for marker in READY_MARKERS:
        if data_root.joinpath(marker).exists():
            return True
    return False
def has_completed_runtime_data(data_root: Path) -> bool:
    """Report whether the ready-marker file from a finished extraction is present."""
    marker_path = data_root.joinpath(READY_MARKER_FILE)
    return marker_path.exists()
def mark_runtime_data_ready(data_root: Path) -> None:
    """Create the ready-marker file so later runs can tell complete data from partial data."""
    marker_path = data_root.joinpath(READY_MARKER_FILE)
    marker_path.write_text("ready\n", encoding="utf-8")
def clear_runtime_data(data_root: Path) -> None:
    """Delete everything in *data_root* except the ``.bundle-cache`` download cache.

    Real directories are removed recursively. Files and symlinks (including symlinks
    to directories) are unlinked, so the targets of linked directories are untouched.
    """
    doomed = (entry for entry in data_root.iterdir() if entry.name != ".bundle-cache")
    for entry in doomed:
        if entry.is_symlink() or not entry.is_dir():
            entry.unlink()
        else:
            shutil.rmtree(entry)
def iter_database_files(source_dir: Path) -> list[Path]:
    """List the ``*.db`` files directly inside *source_dir* (not nested), sorted by path."""
    candidates = list(source_dir.glob("*.db"))
    candidates.sort()
    return candidates
def link_or_copy(source: Path, target: Path, *, symlink: bool) -> None:
    """Materialize *source* at *target*, leaving any existing target alone.

    Args:
        source: File or directory to expose.
        target: Destination path; missing parent directories are created.
        symlink: When true, create a symlink to *source*; otherwise copy it
            (recursively for directories, with metadata for files).
    """
    already_present = target.is_symlink() or target.exists()
    if already_present:
        return
    target.parent.mkdir(parents=True, exist_ok=True)
    if symlink:
        target.symlink_to(source)
        return
    copier = shutil.copytree if source.is_dir() else shutil.copy2
    copier(source, target)
def prepare_from_source_dir(source_dir: Path, data_root: Path) -> None:
    """Populate *data_root* from an already-unpacked, mounted directory.

    Root-level ``*.db`` files are copied if they are missing from *data_root*. Each
    DATA_DIRECTORIES entry that exists in the source and is absent from *data_root*
    (including as a dangling symlink) is symlinked, not copied.

    Raises:
        RuntimeError: If *source_dir* does not exist or is not a directory.
    """
    if not source_dir.exists():
        raise RuntimeError(f"Configured data source does not exist: {source_dir}")
    if not source_dir.is_dir():
        raise RuntimeError(f"Configured data source is not a directory: {source_dir}")
    LOGGER.info("Preparing runtime data from mounted source: %s", source_dir)

    for db_file in iter_database_files(source_dir):
        destination = data_root / db_file.name
        if destination.exists():
            continue
        LOGGER.info("Copying database %s", db_file.name)
        shutil.copy2(db_file, destination)

    for name in DATA_DIRECTORIES:
        origin = source_dir / name
        link = data_root / name
        if not origin.exists() or link.exists() or link.is_symlink():
            continue
        LOGGER.info("Linking data directory %s", name)
        link.symlink_to(origin)
def download_url(url: str, cache_dir: Path) -> Path:
    """Download a data bundle URL into the runtime cache.

    The cached file name is the last segment of the URL path, with the query string
    dropped, falling back to ``search-ui-data.tar``. A non-empty cached copy is
    reused without contacting the server. New downloads stream into ``<name>.tmp``
    and are renamed into place only when complete, so an interrupted or truncated
    transfer never lands in the cache under the final name.

    Args:
        url: Bundle URL (anything ``urlopen`` accepts).
        cache_dir: Directory holding downloaded bundles; created if missing.

    Returns:
        Path to the cached bundle file.

    Raises:
        RuntimeError: If the download fails or times out, or if the received size
            differs from the advertised ``Content-Length``. The temp file is removed.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    filename = Path(url.split("?", 1)[0]).name or "search-ui-data.tar"
    target = cache_dir / filename
    temp_target = target.with_suffix(target.suffix + ".tmp")
    if target.exists() and target.stat().st_size > 0:
        LOGGER.info("Using cached data bundle: %s", target)
        return target
    LOGGER.info("Downloading data bundle from %s", url)
    try:
        # urlopen (unlike urlretrieve) takes a timeout, so a stalled server cannot
        # hang container startup indefinitely.
        with urlopen(url, timeout=URL_TIMEOUT_SECONDS) as response, temp_target.open("wb") as output:
            shutil.copyfileobj(response, output, length=1024 * 1024)
            expected_size = response.headers.get("Content-Length")
    except (OSError, URLError) as exc:
        if temp_target.exists():
            temp_target.unlink()
        raise RuntimeError(f"Unable to download data bundle from {url}: {exc}") from exc
    # Measured after the file is closed so buffered bytes are counted. This keeps the
    # truncation detection urlretrieve provided (ContentTooShortError).
    actual_size = temp_target.stat().st_size
    if expected_size and expected_size.isdigit() and actual_size != int(expected_size):
        temp_target.unlink()
        raise RuntimeError(
            f"Unable to download data bundle from {url}: "
            f"size mismatch: expected {expected_size}, got {actual_size}"
        )
    temp_target.replace(target)
    return target
def download_hub_file(cache_dir: Path) -> Path | None:
    """Fetch the data bundle from a Hugging Face Hub repo, if one is configured.

    Reads SEARCH_UI_DATA_BUNDLE_REPO_ID and SEARCH_UI_DATA_BUNDLE_FILENAME
    (whitespace-stripped). The repo type comes from SEARCH_UI_DATA_BUNDLE_REPO_TYPE
    (default ``dataset``) and the revision from SEARCH_UI_DATA_BUNDLE_REVISION.
    ``huggingface_hub`` is imported only when a download is actually needed.

    Returns:
        The local path of the downloaded file under ``cache_dir / "hub"``, or
        ``None`` when neither repo id nor filename is set.

    Raises:
        RuntimeError: If exactly one of repo id / filename is set.
    """
    repo_id = os.getenv("SEARCH_UI_DATA_BUNDLE_REPO_ID", "").strip()
    filename = os.getenv("SEARCH_UI_DATA_BUNDLE_FILENAME", "").strip()
    configured = (bool(repo_id), bool(filename))
    if configured == (False, False):
        return None
    if configured != (True, True):
        raise RuntimeError(
            "Set both SEARCH_UI_DATA_BUNDLE_REPO_ID and "
            "SEARCH_UI_DATA_BUNDLE_FILENAME, or neither."
        )

    from huggingface_hub import hf_hub_download

    repo_type = os.getenv("SEARCH_UI_DATA_BUNDLE_REPO_TYPE", "dataset").strip() or "dataset"
    revision = os.getenv("SEARCH_UI_DATA_BUNDLE_REVISION", "").strip() or None
    LOGGER.info("Downloading data bundle from Hugging Face %s repo %s", repo_type, repo_id)
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        revision=revision,
        local_dir=str(cache_dir / "hub"),
    )
    return Path(downloaded)
def resolve_bundle(cache_dir: Path) -> Path | None:
    """Locate the configured on-disk data bundle, downloading it if necessary.

    Precedence: SEARCH_UI_DATA_BUNDLE_PATH (a local file), then the Hugging Face
    Hub settings read by ``download_hub_file``, then SEARCH_UI_DATA_BUNDLE_URL.

    Returns:
        Path to the bundle, or ``None`` when no source is configured.

    Raises:
        RuntimeError: If the configured local path does not exist, or if a helper
            download fails.
    """
    explicit = os.getenv("SEARCH_UI_DATA_BUNDLE_PATH", "").strip()
    if explicit:
        candidate = Path(explicit).expanduser().resolve()
        if candidate.exists():
            return candidate
        raise RuntimeError(f"Configured data bundle does not exist: {candidate}")

    from_hub = download_hub_file(cache_dir)
    if from_hub is not None:
        return from_hub

    remote = os.getenv("SEARCH_UI_DATA_BUNDLE_URL", "").strip()
    return download_url(remote, cache_dir) if remote else None
def validate_member_path(target_dir: Path, member_name: str) -> None:
    """Reject archive members that would write outside the data root.

    Both paths are resolved (following any symlinks that already exist) and then
    compared component-wise. Unlike a string-prefix test, this works when
    *target_dir* is the filesystem root, and it never confuses ``/data/x`` with
    ``/data/x-evil``.

    Args:
        target_dir: Extraction root.
        member_name: Member path as stored in the archive.

    Raises:
        RuntimeError: If the member resolves outside *target_dir*. Absolute names
            and ``..`` traversal outside the root are both rejected.
    """
    root = target_dir.resolve()
    destination = (root / member_name).resolve()
    try:
        # relative_to succeeds for root itself and anything beneath it.
        destination.relative_to(root)
    except ValueError:
        raise RuntimeError(f"Unsafe data bundle path rejected: {member_name}") from None
def is_ignorable_archive_member(member_name: str) -> bool:
    """Report whether an archive member is macOS metadata that can be skipped.

    A member is ignorable when any ``/``-separated segment of its name is
    ``__MACOSX`` or starts with ``._``.
    """
    for segment in member_name.split("/"):
        if segment == "__MACOSX" or segment.startswith("._"):
            return True
    return False
def extract_tar_member(archive: tarfile.TarFile, member: tarfile.TarInfo, target_dir: Path) -> None:
    """Write a single tar member beneath *target_dir*, refusing anything unsafe.

    macOS metadata members are skipped silently. Only regular files and directories
    are materialized. Symlinks, hard links and special members raise, as do paths
    that resolve outside *target_dir*.

    Raises:
        RuntimeError: For unsafe paths, link members, unsupported member types, or a
            regular-file member whose data cannot be read.
    """
    name = member.name
    if is_ignorable_archive_member(name):
        return
    # The path check runs before the type checks, so traversal attempts are always
    # reported as unsafe paths.
    validate_member_path(target_dir, name)
    if member.issym() or member.islnk():
        raise RuntimeError(f"Unsafe tar link rejected: {name}")

    output_path = (target_dir / name).resolve()
    if member.isdir():
        output_path.mkdir(parents=True, exist_ok=True)
        return
    if not member.isfile():
        raise RuntimeError(f"Unsupported tar member rejected: {name}")

    payload = archive.extractfile(member)
    if payload is None:
        raise RuntimeError(f"Unable to read tar member: {name}")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with payload, output_path.open("wb") as sink:
        shutil.copyfileobj(payload, sink, length=1024 * 1024)
def extract_tar(archive_path: Path, target_dir: Path) -> None:
    """Extract an on-disk tar archive (any compression tarfile detects) member by member.

    Each member goes through ``extract_tar_member`` for safety checks.
    """
    archive = tarfile.open(archive_path)
    with archive:
        for entry in archive:
            extract_tar_member(archive, entry, target_dir)
def extract_tar_stream(fileobj, target_dir: Path) -> None:
    """Extract a tar archive read sequentially from *fileobj* (mode ``r|*``).

    Members are processed in stream order through ``extract_tar_member``. *fileobj*
    only needs a ``read`` method, and it is not closed here.
    """
    stream_archive = tarfile.open(fileobj=fileobj, mode="r|*")
    with stream_archive:
        for entry in stream_archive:
            extract_tar_member(stream_archive, entry, target_dir)
def read_json_url(url: str) -> dict:
    """Read a JSON document from a URL.

    The raw response bytes are given to ``json.loads``, which detects UTF-8/16/32
    and accepts a leading UTF-8 BOM.

    Args:
        url: Any URL ``urlopen`` accepts; fetched with ``URL_TIMEOUT_SECONDS``.

    Returns:
        The decoded JSON value. Callers expect an object, but that is not enforced here.

    Raises:
        RuntimeError: If the URL cannot be fetched or the body is not decodable JSON.
    """
    try:
        with urlopen(url, timeout=URL_TIMEOUT_SECONDS) as response:
            payload = response.read()
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError; the
        # latter would otherwise escape as a raw traceback instead of RuntimeError.
        return json.loads(payload)
    except (OSError, URLError, ValueError) as exc:
        raise RuntimeError(f"Unable to read data bundle manifest from {url}: {exc}") from exc
def resolve_part_urls(manifest: dict) -> list[str]:
    """Resolve data bundle part URLs from a manifest document.

    The manifest must be a JSON object with a non-empty ``parts`` list. Each entry is
    either a URL string or an object with a string ``url`` field. Order is preserved,
    since parts are concatenated in this order.

    Returns:
        The part URLs, in manifest order.

    Raises:
        RuntimeError: If the manifest is not an object, ``parts`` is missing or empty,
            an entry has the wrong shape, or a URL is not HTTP(S).
    """
    if not isinstance(manifest, dict):
        # e.g. a top-level JSON list; .get() would otherwise raise AttributeError.
        raise RuntimeError("Data bundle parts manifest must be a JSON object.")
    raw_parts = manifest.get("parts")
    if not isinstance(raw_parts, list) or not raw_parts:
        raise RuntimeError("Data bundle parts manifest must include a non-empty parts list.")
    urls: list[str] = []
    for index, item in enumerate(raw_parts, start=1):
        if isinstance(item, str):
            url = item
        elif isinstance(item, dict) and isinstance(item.get("url"), str):
            url = item["url"]
        else:
            raise RuntimeError(f"Invalid data bundle part at index {index}.")
        if not url.startswith(("http://", "https://")):
            raise RuntimeError(f"Data bundle part URL must be HTTP(S): {url}")
        urls.append(url)
    return urls
def download_part(url: str, target_path: Path, *, index: int, total: int, attempts: int = 3) -> Path:
    """Download one bundle part to a temporary local file before reading it.

    The body is streamed into ``<target>.tmp``. It is renamed onto *target_path* only
    after the received size matches ``Content-Length`` (when the server sends a
    numeric one), so a truncated transfer is never treated as a complete part. Failed
    attempts are retried with a short exponential backoff.

    Args:
        url: Part URL.
        target_path: Final location of the part; parent directories are created.
        index: 1-based part number, used only for logging.
        total: Total number of parts, used only for logging.
        attempts: Maximum number of download attempts (at least one is made).

    Returns:
        *target_path*.

    Raises:
        RuntimeError: After every attempt fails. The partial temp file is removed so
            it does not occupy cache space.
    """
    import time  # lazy import, matching download_hub_file's local-import style

    target_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = target_path.with_suffix(target_path.suffix + ".tmp")
    total_attempts = max(1, attempts)
    last_error: Exception | None = None
    for attempt in range(1, total_attempts + 1):
        if temp_path.exists():
            temp_path.unlink()
        try:
            LOGGER.info("Downloading data bundle part %s/%s", index, total)
            with urlopen(url, timeout=URL_TIMEOUT_SECONDS) as response, temp_path.open("wb") as output:
                shutil.copyfileobj(response, output, length=1024 * 1024)
                expected_size = response.headers.get("Content-Length")
            # Sized after the file is closed so no buffered bytes are missed.
            actual_size = temp_path.stat().st_size
            if expected_size and actual_size != int(expected_size):
                raise RuntimeError(
                    f"part size mismatch: expected {expected_size}, got {actual_size}"
                )
            temp_path.replace(target_path)
            return target_path
        except (OSError, URLError, ValueError, RuntimeError) as exc:
            # ValueError: a malformed Content-Length header is treated as a failed attempt.
            last_error = exc
            LOGGER.warning(
                "Failed to download data bundle part %s/%s on attempt %s: %s",
                index,
                total,
                attempt,
                exc,
            )
            if attempt < total_attempts:
                time.sleep(min(2 ** attempt, 30))
    if temp_path.exists():
        temp_path.unlink()
    raise RuntimeError(
        f"Unable to download data bundle part {index}/{total}: {last_error}"
    ) from last_error
class SequentialPartReader:
    """Read-only file-like object that concatenates bundle parts on demand.

    Parts are fetched one at a time with ``download_part`` into ``cache_dir`` as
    ``bundle-part-0001``, ``bundle-part-0002``, and so on. The current part file is
    deleted whenever the reader moves on or ``close`` is called, so at most one
    completed part sits on disk at a time. It exposes ``read`` and ``close``, and it
    is passed as the ``fileobj`` of a streaming tar extraction.
    """

    def __init__(self, part_urls: list[str], cache_dir: Path):
        self.current_file = None
        self.current_path: Path | None = None
        # How many parts have been opened; equals the 1-based number of the current part.
        self.index = 0
        self.part_urls = part_urls
        self.total = len(part_urls)
        self.cache_dir = cache_dir

    def close(self) -> None:
        """Close and delete the current part file, if any. Safe to call repeatedly."""
        handle, self.current_file = self.current_file, None
        if handle is not None:
            handle.close()
        path, self.current_path = self.current_path, None
        if path is not None and path.exists():
            path.unlink()

    def _open_next_part(self) -> bool:
        """Discard the current part and download/open the next one.

        Returns False, without downloading, once every part has been consumed.
        """
        self.close()
        if self.index >= self.total:
            return False
        position = self.index
        self.index = position + 1
        url = self.part_urls[position]
        local = self.cache_dir / f"bundle-part-{self.index:04d}"
        self.current_path = download_part(url, local, index=self.index, total=self.total)
        self.current_file = self.current_path.open("rb")
        return True

    def read(self, size: int = -1) -> bytes:
        """Return up to *size* bytes, crossing part boundaries as needed.

        NOTE: a negative or ``None`` size reads at most 1 MiB, not the rest of the
        stream. An empty result means every part has been consumed.
        """
        budget = 1024 * 1024 if size is None or size < 0 else size
        buffer = bytearray()
        while len(buffer) < budget:
            if self.current_file is None and not self._open_next_part():
                break
            piece = self.current_file.read(budget - len(buffer))
            if not piece:
                # Current part exhausted: drop it and fall through to the next one.
                self.close()
                continue
            buffer += piece
        return bytes(buffer)
def stream_extract_url(url: str, target_dir: Path) -> None:
    """Extract a remote tar bundle straight from the HTTP response, without a local copy.

    Raises:
        RuntimeError: Wrapping network, filesystem, or tar-format errors.
            RuntimeErrors from member validation propagate unchanged.
    """
    LOGGER.info("Streaming data bundle from %s", url)
    failure_types = (OSError, URLError, tarfile.TarError)
    try:
        response = urlopen(url, timeout=URL_TIMEOUT_SECONDS)
        with response:
            extract_tar_stream(response, target_dir)
    except failure_types as exc:
        raise RuntimeError(f"Unable to stream data bundle from {url}: {exc}") from exc
def stream_extract_parts_manifest(manifest_url: str, target_dir: Path, cache_dir: Path) -> None:
    """Extract a split tar bundle whose part URLs are listed in a JSON manifest.

    Parts are downloaded one at a time into ``cache_dir / "parts"`` and fed to the
    tar stream in manifest order. The last part file is removed even on failure.

    Raises:
        RuntimeError: For manifest problems, failed part downloads, or (wrapped)
            filesystem and tar-format errors.
    """
    LOGGER.info("Streaming split data bundle from manifest %s", manifest_url)
    urls = resolve_part_urls(read_json_url(manifest_url))
    reader = SequentialPartReader(urls, cache_dir / "parts")
    try:
        extract_tar_stream(reader, target_dir)
    except (OSError, URLError, tarfile.TarError) as exc:
        raise RuntimeError(f"Unable to stream split data bundle: {exc}") from exc
    finally:
        reader.close()
def extract_zip(archive_path: Path, target_dir: Path) -> None:
    """Safely extract a zip archive into the data root.

    macOS metadata members (a ``__MACOSX`` path segment, or one starting with
    ``._``) are skipped, consistent with the tar extraction path. Every remaining
    member name is validated before anything is written, so an archive containing an
    unsafe path extracts nothing.

    Raises:
        RuntimeError: If any extracted member would land outside *target_dir*.
    """
    with zipfile.ZipFile(archive_path) as archive:
        members = [
            info
            for info in archive.infolist()
            if not is_ignorable_archive_member(info.filename)
        ]
        for info in members:
            validate_member_path(target_dir, info.filename)
        archive.extractall(target_dir, members=members)
def extract_bundle(archive_path: Path, target_dir: Path) -> None:
    """Extract an on-disk bundle, detecting zip first and then tar by content.

    Raises:
        RuntimeError: If the file is neither a zip nor a tar archive.
    """
    LOGGER.info("Extracting data bundle %s into %s", archive_path, target_dir)
    handlers = (
        (zipfile.is_zipfile, extract_zip),
        (tarfile.is_tarfile, extract_tar),
    )
    for looks_like, extractor in handlers:
        if looks_like(archive_path):
            extractor(archive_path, target_dir)
            return
    raise RuntimeError(f"Unsupported data bundle format: {archive_path}")
def main() -> int:
    """Prepare runtime data and return a process exit code.

    Data sources are checked in this order, and the first one configured wins:

    1. ``SEARCH_UI_DATA_SOURCE_DIR``: copy databases and symlink directories from an
       already-unpacked mounted directory.
    2. ``SEARCH_UI_DATA_BUNDLE_PARTS_MANIFEST_URL``: stream-extract a split tar bundle.
    3. ``SEARCH_UI_DATA_BUNDLE_URL``: stream-extract a tar bundle, unless
       ``SEARCH_UI_STREAM_DATA_BUNDLE`` is explicitly false.
    4. ``resolve_bundle``: a local path, a Hugging Face repo file, or a downloaded
       URL, extracted from disk.

    A truthy ``SEARCH_UI_FORCE_DATA_BUNDLE`` clears existing runtime data (the
    download cache is kept) and prepares it again. Bundle-based sources reuse data
    only when the ready marker is present. Unmarked data, e.g. from an interrupted
    extraction, is cleared and re-extracted.

    Returns:
        0 on success. Failures propagate as exceptions (mostly ``RuntimeError``).
    """
    configure_logging()
    data_root = Path(os.getenv("SEARCH_UI_DATA_ROOT", DEFAULT_DATA_ROOT)).resolve()
    cache_dir = data_root / ".bundle-cache"
    force = env_truthy("SEARCH_UI_FORCE_DATA_BUNDLE")
    data_root.mkdir(parents=True, exist_ok=True)

    source_dir = os.getenv("SEARCH_UI_DATA_SOURCE_DIR", "").strip()
    if source_dir:
        if has_runtime_data(data_root) and not force:
            LOGGER.info("Using existing runtime data in %s", data_root)
        else:
            if force:
                LOGGER.info("Clearing existing runtime data before source preparation.")
                clear_runtime_data(data_root)
            prepare_from_source_dir(Path(source_dir).expanduser().resolve(), data_root)
        return 0

    parts_manifest_url = os.getenv("SEARCH_UI_DATA_BUNDLE_PARTS_MANIFEST_URL", "").strip()
    if parts_manifest_url:
        if has_completed_runtime_data(data_root) and not force:
            LOGGER.info("Using existing runtime data in %s", data_root)
            return 0
        if force or has_runtime_data(data_root):
            LOGGER.info("Clearing existing runtime data before split bundle extraction.")
            clear_runtime_data(data_root)
        stream_extract_parts_manifest(parts_manifest_url, data_root, cache_dir)
        mark_runtime_data_ready(data_root)
        LOGGER.info("Runtime data is ready in %s", data_root)
        return 0

    url = os.getenv("SEARCH_UI_DATA_BUNDLE_URL", "").strip()
    if url and not env_falsey("SEARCH_UI_STREAM_DATA_BUNDLE"):
        if has_completed_runtime_data(data_root) and not force:
            LOGGER.info("Using existing runtime data in %s", data_root)
            return 0
        if force or has_runtime_data(data_root):
            LOGGER.info("Clearing existing runtime data before stream extraction.")
            clear_runtime_data(data_root)
        stream_extract_url(url, data_root)
        mark_runtime_data_ready(data_root)
        LOGGER.info("Runtime data is ready in %s", data_root)
        return 0

    bundle_path = resolve_bundle(cache_dir)
    if bundle_path is None:
        if has_runtime_data(data_root):
            LOGGER.info("Using existing runtime data in %s", data_root)
        else:
            LOGGER.warning(
                "No data bundle configured. The app will start, but searches will be empty."
            )
        return 0
    # Same completeness rule as the streaming branches: the marker written below is
    # the only proof that a previous extraction finished.
    if has_completed_runtime_data(data_root) and not force:
        LOGGER.info(
            "Runtime data already exists in %s. Set SEARCH_UI_FORCE_DATA_BUNDLE=1 "
            "to re-extract the configured bundle.",
            data_root,
        )
        return 0
    if force or has_runtime_data(data_root):
        LOGGER.info("Clearing existing runtime data before bundle extraction.")
        clear_runtime_data(data_root)
    extract_bundle(bundle_path, data_root)
    mark_runtime_data_ready(data_root)
    LOGGER.info("Runtime data is ready in %s", data_root)
    return 0
if __name__ == "__main__":
    # Run as a script: main()'s return value becomes the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)