# mmu_lsdb_benchmark / benchmark.py
# (Hugging Face viewer residue preserved as comments: uploader "Smith42",
#  commit message "Change dataset", commit d4b8ec6)
"""
LSDB crossmatch benchmark — core logic.
Compares LOCAL (snapshot_download → disk) vs REMOTE (HTTPS) crossmatching
of two UniverseTBD MMU HATS catalogs.
"""
import time
from dataclasses import asdict, dataclass
from pathlib import Path
import lsdb
from dask.distributed import Client
from huggingface_hub import snapshot_download
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Hugging Face dataset repos holding the two HATS catalogs to crossmatch.
CATALOG_LEFT = "LSDB/mmu_sdss_sdss"
CATALOG_RIGHT = "UniverseTBD/mmu_desi_provabgs"
# URL template for direct-HTTPS access to a repo's files (REMOTE scenario).
HF_RESOLVE = "https://huggingface.co/datasets/{repo}/resolve/main"
XMATCH_RADIUS = 1.0  # crossmatch radius, arcsec
XMATCH_N = 1  # neighbors to keep per left-hand source
# Cone-search region: RA/Dec in degrees, radius in arcsec (3600" = 1 deg).
CONE_RA, CONE_DEC, CONE_RADIUS = 150.0, 2.0, 3600.0
LOCAL_DIR = Path("/tmp/hats_local")  # snapshot_download target (LOCAL scenario)
N_WORKERS = 2  # Dask workers
THREADS = 1  # threads per Dask worker
@dataclass
class Result:
    """Metrics for one benchmark scenario (times in seconds, memory in MB)."""

    scenario: str              # scenario label, e.g. "local" or "remote"
    download_s: float = 0.0    # snapshot_download wall time (local scenario only)
    open_s: float = 0.0        # time to open both catalogs
    plan_s: float = 0.0        # time to build the lazy crossmatch plan
    compute_s: float = 0.0     # time to execute the plan
    total_s: float = 0.0       # wall time of the whole scenario, including failures
    n_part_left: int = 0       # HEALPix partitions in the left catalog
    n_part_right: int = 0      # HEALPix partitions in the right catalog
    n_rows: int = 0            # crossmatched rows returned by compute()
    peak_mb: float = 0.0       # best-effort peak Dask worker memory
    error: str = ""            # error message when the scenario failed, else ""

    def to_dict(self):
        """Return the metrics as a plain dict (via dataclasses.asdict)."""
        return asdict(self)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _peak_mb(client):
try:
return max(
w.get("metrics", {}).get("memory", 0)
for w in client.scheduler_info()["workers"].values()
) / 1e6
except Exception:
return 0.0
def _find_hats_root(p: Path) -> str:
markers = ["properties", "partition_info.csv", "_common_metadata"]
for depth_iter in [
[p],
sorted(d for d in p.iterdir() if d.is_dir() and not d.name.startswith(".")),
[gc for d in p.iterdir() if d.is_dir() and not d.name.startswith(".")
for gc in sorted(d.iterdir()) if gc.is_dir() and not gc.name.startswith(".")],
]:
for candidate in depth_iter:
if any((candidate / m).exists() for m in markers):
return str(candidate)
return str(p)
def _open(path, cone, label):
    """Open an LSDB catalog, applying *cone* as a search filter when possible.

    When the filtered open raises, falls back to an unfiltered open (the
    filter is silently dropped — best-effort, matching prior behavior).
    Logs the partition and column counts of the opened catalog.
    """
    print(f" Opening {label}: {path}")
    if cone:
        try:
            cat = lsdb.open_catalog(path, search_filter=cone)
        except Exception:
            # Filtered open rejected (e.g. lsdb/catalog mismatch): retry plain.
            cat = lsdb.open_catalog(path)
    else:
        cat = lsdb.open_catalog(path)
    n_partitions = cat.get_healpix_pixels().shape[0]
    print(f" -> {n_partitions} partitions, {len(cat.columns)} cols")
    return cat
# ---------------------------------------------------------------------------
# Scenarios
# ---------------------------------------------------------------------------
def run_local(client, use_cone=True) -> Result:
    """Benchmark the LOCAL scenario: snapshot_download both catalogs to
    disk, then open and crossmatch the local copies.

    Returns a Result with per-phase wall times; on failure the error text
    is recorded and any phase timings already collected are kept.
    """
    res = Result("local")
    start = time.perf_counter()
    try:
        search = lsdb.ConeSearch(CONE_RA, CONE_DEC, radius_arcsec=CONE_RADIUS) if use_cone else None

        # Phase 1: download both catalog snapshots to local disk.
        phase = time.perf_counter()
        left_dir = LOCAL_DIR / CATALOG_LEFT.replace("/", "__")
        right_dir = LOCAL_DIR / CATALOG_RIGHT.replace("/", "__")
        snapshot_download(CATALOG_LEFT, repo_type="dataset", local_dir=str(left_dir))
        snapshot_download(CATALOG_RIGHT, repo_type="dataset", local_dir=str(right_dir))
        res.download_s = time.perf_counter() - phase

        # Phase 2: open both catalogs from disk (cone filter best-effort).
        phase = time.perf_counter()
        left = _open(_find_hats_root(left_dir), search, "left/local")
        right = _open(_find_hats_root(right_dir), search, "right/local")
        res.open_s = time.perf_counter() - phase
        res.n_part_left = left.get_healpix_pixels().shape[0]
        res.n_part_right = right.get_healpix_pixels().shape[0]

        # Phase 3: build the lazy crossmatch plan (no computation yet).
        phase = time.perf_counter()
        plan = left.crossmatch(
            right,
            radius_arcsec=XMATCH_RADIUS,
            n_neighbors=XMATCH_N,
            suffixes=("_sdss", "_desi"),
        )
        res.plan_s = time.perf_counter() - phase

        # Phase 4: execute the plan on the Dask cluster.
        phase = time.perf_counter()
        matched = plan.compute()
        res.compute_s = time.perf_counter() - phase
        res.n_rows = len(matched)
    except Exception as exc:
        # Record the failure but keep whatever timings were collected.
        res.error = str(exc)
        import traceback
        traceback.print_exc()
    res.total_s = time.perf_counter() - start
    res.peak_mb = _peak_mb(client)
    return res
def run_remote(client, use_cone=True) -> Result:
    """Benchmark the REMOTE scenario: open and crossmatch both catalogs
    directly over HTTPS (no local download).

    Returns a Result with per-phase wall times; on failure the error text
    is recorded and any phase timings already collected are kept.
    """
    res = Result("remote")
    start = time.perf_counter()
    try:
        search = lsdb.ConeSearch(CONE_RA, CONE_DEC, radius_arcsec=CONE_RADIUS) if use_cone else None
        left_url = HF_RESOLVE.format(repo=CATALOG_LEFT)
        right_url = HF_RESOLVE.format(repo=CATALOG_RIGHT)

        # Phase 1: open both catalogs over HTTPS (cone filter best-effort).
        phase = time.perf_counter()
        left = _open(left_url, search, "left/remote")
        right = _open(right_url, search, "right/remote")
        res.open_s = time.perf_counter() - phase
        res.n_part_left = left.get_healpix_pixels().shape[0]
        res.n_part_right = right.get_healpix_pixels().shape[0]

        # Phase 2: build the lazy crossmatch plan (no computation yet).
        phase = time.perf_counter()
        plan = left.crossmatch(
            right,
            radius_arcsec=XMATCH_RADIUS,
            n_neighbors=XMATCH_N,
            suffixes=("_sdss", "_desi"),
        )
        res.plan_s = time.perf_counter() - phase

        # Phase 3: execute the plan on the Dask cluster.
        phase = time.perf_counter()
        matched = plan.compute()
        res.compute_s = time.perf_counter() - phase
        res.n_rows = len(matched)
    except Exception as exc:
        # Record the failure but keep whatever timings were collected.
        res.error = str(exc)
        import traceback
        traceback.print_exc()
    res.total_s = time.perf_counter() - start
    res.peak_mb = _peak_mb(client)
    return res
def run_benchmark(use_cone=True):
    """Run both scenarios and return (local_result, remote_result, summary)."""
    # One Dask cluster is shared by both scenarios and torn down afterwards.
    client = Client(n_workers=N_WORKERS, threads_per_worker=THREADS, memory_limit="auto")
    try:
        print("=== LOCAL ===")
        local = run_local(client, use_cone)
        print(f"Local done: {local.compute_s:.2f}s compute, {local.n_rows} rows\n")
        print("=== REMOTE ===")
        remote = run_remote(client, use_cone)
        print(f"Remote done: {remote.compute_s:.2f}s compute, {remote.n_rows} rows\n")
    finally:
        client.close()
    # Speedup of local over remote compute; 0 when the local run produced
    # no usable compute time (e.g. it failed before computing).
    if local.compute_s > 0:
        speedup = remote.compute_s / local.compute_s
    else:
        speedup = 0
    cone_desc = f"({CONE_RA}, {CONE_DEC}, r={CONE_RADIUS}\")" if use_cone else "full sky"
    summary = {
        "compute_speedup": round(speedup, 1),
        "local_compute_s": round(local.compute_s, 2),
        "remote_compute_s": round(remote.compute_s, 2),
        "local_download_s": round(local.download_s, 2),
        "catalog_left": CATALOG_LEFT,
        "catalog_right": CATALOG_RIGHT,
        "cone": cone_desc,
    }
    return local, remote, summary