Spaces:
Paused
Paused
| """ | |
| LSDB crossmatch benchmark — core logic. | |
| Compares LOCAL (snapshot_download → disk) vs REMOTE (HTTPS) crossmatching | |
| of two UniverseTBD MMU HATS catalogs. | |
| """ | |
| import time | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| import lsdb | |
| from dask.distributed import Client | |
| from huggingface_hub import snapshot_download | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# Hugging Face dataset repo ids of the two catalogs being crossmatched.
CATALOG_LEFT = "LSDB/mmu_sdss_sdss"
CATALOG_RIGHT = "UniverseTBD/mmu_desi_provabgs"

# URL template for remote access; ``{repo}`` is filled in with a repo id.
HF_RESOLVE = "https://huggingface.co/datasets/{repo}/resolve/main"

# Crossmatch parameters, passed as ``radius_arcsec`` / ``n_neighbors``.
XMATCH_RADIUS = 1.0
XMATCH_N = 1

# Cone-search region: the centre is passed positionally to ``lsdb.ConeSearch``
# (presumably RA/Dec in degrees — confirm against LSDB docs); the radius is
# passed as ``radius_arcsec``.
CONE_RA = 150.0
CONE_DEC = 2.0
CONE_RADIUS = 3600.0

# Parent directory under which each catalog snapshot is downloaded.
LOCAL_DIR = Path("/tmp/hats_local")

# Dask cluster shape: ``n_workers`` and ``threads_per_worker`` of the Client.
N_WORKERS = 2
THREADS = 1
@dataclass
class Result:
    """Timings and counters recorded for a single benchmark scenario.

    The class must be a dataclass: instances are created positionally
    (``Result("local")``) and serialized with ``dataclasses.asdict``, both of
    which fail on a plain class that only carries annotations.
    """

    scenario: str  # scenario label, e.g. "local" or "remote"
    download_s: float = 0.0  # seconds spent downloading snapshots
    open_s: float = 0.0  # seconds spent opening both catalogs
    plan_s: float = 0.0  # seconds spent building the crossmatch
    compute_s: float = 0.0  # seconds spent in ``compute()``
    total_s: float = 0.0  # wall-clock seconds for the whole scenario
    n_part_left: int = 0  # partition count of the left catalog
    n_part_right: int = 0  # partition count of the right catalog
    n_rows: int = 0  # number of rows in the computed crossmatch
    peak_mb: float = 0.0  # worker memory figure in MB (bytes / 1e6)
    error: str = ""  # error text; empty string means no error was recorded

    def to_dict(self) -> dict:
        """Return all fields as a plain ``dict`` keyed by field name."""
        return asdict(self)
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _peak_mb(client): | |
| try: | |
| return max( | |
| w.get("metrics", {}).get("memory", 0) | |
| for w in client.scheduler_info()["workers"].values() | |
| ) / 1e6 | |
| except Exception: | |
| return 0.0 | |
| def _find_hats_root(p: Path) -> str: | |
| markers = ["properties", "partition_info.csv", "_common_metadata"] | |
| for depth_iter in [ | |
| [p], | |
| sorted(d for d in p.iterdir() if d.is_dir() and not d.name.startswith(".")), | |
| [gc for d in p.iterdir() if d.is_dir() and not d.name.startswith(".") | |
| for gc in sorted(d.iterdir()) if gc.is_dir() and not gc.name.startswith(".")], | |
| ]: | |
| for candidate in depth_iter: | |
| if any((candidate / m).exists() for m in markers): | |
| return str(candidate) | |
| return str(p) | |
def _open(path, cone, label):
    """Open a catalog with ``lsdb.open_catalog`` and print a short summary.

    When ``cone`` is truthy the catalog is opened with ``search_filter=cone``.
    If that filtered open fails, the failure is printed and the catalog is
    opened once more without the filter. The fallback is best-effort and
    deliberately broad, but note that it yields the unfiltered catalog.
    When ``cone`` is falsy there is nothing to fall back to, so the catalog
    is opened exactly once and any error propagates.

    Args:
        path: Catalog location handed to ``lsdb.open_catalog``.
        cone: Search filter to apply, or a falsy value for no filter.
        label: Short tag used in the progress output.

    Returns:
        The opened catalog object.
    """
    print(f" Opening {label}: {path}")
    if cone:
        try:
            cat = lsdb.open_catalog(path, search_filter=cone)
        except Exception as exc:
            # Keep the fallback, but make it visible: the benchmark numbers
            # mean something different when the filter was not applied.
            print(f" search_filter open failed ({exc!r}); retrying without filter")
            cat = lsdb.open_catalog(path)
    else:
        cat = lsdb.open_catalog(path)
    # len() gives the same count as .shape[0] for a 1-D array and also works
    # if the pixels come back as a plain list.
    n = len(cat.get_healpix_pixels())
    print(f" -> {n} partitions, {len(cat.columns)} cols")
    return cat
| # --------------------------------------------------------------------------- | |
| # Scenarios | |
| # --------------------------------------------------------------------------- | |
def run_local(client, use_cone=True) -> Result:
    """Run the LOCAL scenario: download both catalogs, then crossmatch.

    Each phase (download, open, plan, compute) is timed separately with
    ``time.perf_counter``. Any exception is caught, printed with its
    traceback, and recorded in ``Result.error``; the timings collected before
    the failure are kept.

    Args:
        client: Dask client whose scheduler info is queried for worker
            memory after the run.
        use_cone: When true, both catalogs are opened with a cone-search
            filter built from the module-level ``CONE_*`` constants.

    Returns:
        A ``Result`` labelled ``"local"``.
    """
    r = Result("local")
    t_total = time.perf_counter()
    try:
        cone = lsdb.ConeSearch(CONE_RA, CONE_DEC, radius_arcsec=CONE_RADIUS) if use_cone else None

        # Phase 1: download full snapshots of both dataset repos.
        t0 = time.perf_counter()
        ld = LOCAL_DIR / CATALOG_LEFT.replace("/", "__")
        rd = LOCAL_DIR / CATALOG_RIGHT.replace("/", "__")
        snapshot_download(CATALOG_LEFT, repo_type="dataset", local_dir=str(ld))
        snapshot_download(CATALOG_RIGHT, repo_type="dataset", local_dir=str(rd))
        r.download_s = time.perf_counter() - t0

        # Phase 2: open the catalogs from disk.
        t0 = time.perf_counter()
        left = _open(_find_hats_root(ld), cone, "left/local")
        right = _open(_find_hats_root(rd), cone, "right/local")
        r.open_s = time.perf_counter() - t0
        # len() matches .shape[0] for a 1-D array and also accepts a list.
        r.n_part_left = len(left.get_healpix_pixels())
        r.n_part_right = len(right.get_healpix_pixels())

        # Phase 3: build the crossmatch (timed separately from compute).
        t0 = time.perf_counter()
        xm = left.crossmatch(right, radius_arcsec=XMATCH_RADIUS,
                             n_neighbors=XMATCH_N, suffixes=("_sdss", "_desi"))
        r.plan_s = time.perf_counter() - t0

        # Phase 4: execute.
        t0 = time.perf_counter()
        df = xm.compute()
        r.compute_s = time.perf_counter() - t0
        r.n_rows = len(df)
    except Exception as e:
        import traceback

        # An empty ``error`` means success, so never leave it empty after a
        # failure: exceptions raised without a message stringify to "".
        r.error = str(e) or type(e).__name__
        traceback.print_exc()
    r.total_s = time.perf_counter() - t_total
    r.peak_mb = _peak_mb(client)
    return r
def run_remote(client, use_cone=True) -> Result:
    """Run the REMOTE scenario: crossmatch both catalogs directly over HTTPS.

    The catalogs are opened from URLs built with ``HF_RESOLVE``; nothing is
    downloaded up front, so ``download_s`` stays at its default. Each phase
    (open, plan, compute) is timed separately with ``time.perf_counter``.
    Any exception is caught, printed with its traceback, and recorded in
    ``Result.error``; the timings collected before the failure are kept.

    Args:
        client: Dask client whose scheduler info is queried for worker
            memory after the run.
        use_cone: When true, both catalogs are opened with a cone-search
            filter built from the module-level ``CONE_*`` constants.

    Returns:
        A ``Result`` labelled ``"remote"``.
    """
    r = Result("remote")
    t_total = time.perf_counter()
    try:
        cone = lsdb.ConeSearch(CONE_RA, CONE_DEC, radius_arcsec=CONE_RADIUS) if use_cone else None
        lu = HF_RESOLVE.format(repo=CATALOG_LEFT)
        ru = HF_RESOLVE.format(repo=CATALOG_RIGHT)

        # Phase 1: open the catalogs from their remote URLs.
        t0 = time.perf_counter()
        left = _open(lu, cone, "left/remote")
        right = _open(ru, cone, "right/remote")
        r.open_s = time.perf_counter() - t0
        # len() matches .shape[0] for a 1-D array and also accepts a list.
        r.n_part_left = len(left.get_healpix_pixels())
        r.n_part_right = len(right.get_healpix_pixels())

        # Phase 2: build the crossmatch (timed separately from compute).
        t0 = time.perf_counter()
        xm = left.crossmatch(right, radius_arcsec=XMATCH_RADIUS,
                             n_neighbors=XMATCH_N, suffixes=("_sdss", "_desi"))
        r.plan_s = time.perf_counter() - t0

        # Phase 3: execute.
        t0 = time.perf_counter()
        df = xm.compute()
        r.compute_s = time.perf_counter() - t0
        r.n_rows = len(df)
    except Exception as e:
        import traceback

        # An empty ``error`` means success, so never leave it empty after a
        # failure: exceptions raised without a message stringify to "".
        r.error = str(e) or type(e).__name__
        traceback.print_exc()
    r.total_s = time.perf_counter() - t_total
    r.peak_mb = _peak_mb(client)
    return r
def run_benchmark(use_cone=True):
    """Run both scenarios and return (local_result, remote_result, summary).

    A Dask ``Client`` is started for the duration of the two runs and closed
    afterwards even if a scenario raises. ``summary`` is a dict with the
    rounded compute timings, the remote/local compute-time ratio (0 when the
    local compute time is not positive), the catalog ids, and a description
    of the sky region used.
    """
    dask_client = Client(
        n_workers=N_WORKERS,
        threads_per_worker=THREADS,
        memory_limit="auto",
    )
    try:
        print("=== LOCAL ===")
        local = run_local(dask_client, use_cone)
        print(f"Local done: {local.compute_s:.2f}s compute, {local.n_rows} rows\n")

        print("=== REMOTE ===")
        remote = run_remote(dask_client, use_cone)
        print(f"Remote done: {remote.compute_s:.2f}s compute, {remote.n_rows} rows\n")
    finally:
        dask_client.close()

    # Ratio of remote to local compute time; guarded against division by zero.
    if local.compute_s > 0:
        speedup = remote.compute_s / local.compute_s
    else:
        speedup = 0

    if use_cone:
        region = f"({CONE_RA}, {CONE_DEC}, r={CONE_RADIUS}\")"
    else:
        region = "full sky"

    summary = {
        "compute_speedup": round(speedup, 1),
        "local_compute_s": round(local.compute_s, 2),
        "remote_compute_s": round(remote.compute_s, 2),
        "local_download_s": round(local.download_s, 2),
        "catalog_left": CATALOG_LEFT,
        "catalog_right": CATALOG_RIGHT,
        "cone": region,
    }
    return local, remote, summary