Smith42's picture
Changes to be committed:
2fe731a
"""
FastAPI app for the HF Space — serves a web UI and runs the benchmark.
"""
import asyncio
from string import Template
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
# Module-level record of the most recent benchmark run. The HTTP handlers in
# this file read it and the background worker updates it in place.
state = {
    # Lifecycle marker: "idle", "running", "done" or "error".
    "status": "idle",
    # Result of the LOCAL (disk) run, stored as a dict once a run finishes.
    "local": None,
    # Result of the REMOTE (HTTPS) run, stored as a dict once a run finishes.
    "remote": None,
    # Summary object returned by the benchmark (read with dict-style access).
    "summary": None,
    # Accumulated log lines, oldest first.
    "log": [],
}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager that performs no startup or shutdown work.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None: control is handed back at the single ``yield``; no code runs
        before or after it.
    """
    yield
# FastAPI application instance; the @app.get / @app.post decorators below
# register their route handlers on it.
app = FastAPI(lifespan=lifespan)
# Full HTML document for the landing page, held as a string.Template.
# Placeholders filled in at render time: $status_class, $status_text,
# $results_html and $log_text. No other "$" appears in the markup, so
# Template.substitute() needs exactly those four values.
# The inline script POSTs to /run, then polls /status every 2000 ms, updating
# the log element and reloading the page once the status is "done" or "error".
# NOTE: '\\n' below is deliberate -- Python turns it into the two characters
# \n, which the emitted JavaScript then reads as its own newline escape.
PAGE = Template("""<!DOCTYPE html>
<html>
<head>
<title>LSDB Crossmatch Benchmark</title>
<style>
body { font-family: system-ui, sans-serif; max-width: 820px; margin: 2rem auto; padding: 0 1rem; color: #e0e0e0; background: #1a1a2e; }
h1 { color: #a78bfa; }
h2 { color: #818cf8; margin-top: 2rem; }
.card { background: #16213e; border-radius: 8px; padding: 1.2rem; margin: 1rem 0; border: 1px solid #2a2a4a; }
table { border-collapse: collapse; width: 100%; }
th, td { text-align: left; padding: 0.5rem 0.8rem; border-bottom: 1px solid #2a2a4a; }
th { color: #a78bfa; }
.key { color: #fbbf24; font-weight: bold; font-size: 1.1em; }
.btn { background: #6d28d9; color: white; border: none; padding: 0.7rem 1.5rem; border-radius: 6px; cursor: pointer; font-size: 1rem; }
.btn:hover { background: #7c3aed; }
.btn:disabled { background: #4a4a6a; cursor: not-allowed; }
.status { padding: 0.5rem 1rem; border-radius: 4px; display: inline-block; margin: 0.5rem 0; }
.idle { background: #1e3a5f; }
.running { background: #854d0e; }
.done { background: #166534; }
.error { background: #991b1b; }
pre { background: #0f0f23; padding: 1rem; border-radius: 6px; overflow-x: auto; font-size: 0.85rem; max-height: 400px; overflow-y: auto; }
.speedup { font-size: 2rem; color: #34d399; font-weight: bold; }
.datasets { color: #94a3b8; font-size: 0.9rem; }
</style>
</head>
<body>
<h1>&#128301; LSDB Crossmatch Colocation Benchmark</h1>
<p class="datasets">
<strong>Left:</strong> UniverseTBD/mmu_sdss_sdss (~812k rows) &nbsp;&times;&nbsp;
<strong>Right:</strong> UniverseTBD/mmu_desi_edr_sv3 (~1.16M rows)
</p>
<div class="card">
<p><strong>What this tests:</strong> Does co-locating compute with HATS data
on HF Hub eliminate the network bottleneck for crossmatching?</p>
<p><strong>LOCAL</strong> = snapshot_download &rarr; crossmatch from disk (simulates server-side)<br>
<strong>REMOTE</strong> = crossmatch via HTTPS resolve URLs (simulates external user)</p>
</div>
<button class="btn" id="runBtn" onclick="startBenchmark()">Run Benchmark</button>
<span class="status $status_class" id="status">$status_text</span>
$results_html
<h2>Log</h2>
<pre id="log">$log_text</pre>
<script>
async function startBenchmark() {
document.getElementById('runBtn').disabled = true;
document.getElementById('status').textContent = 'Running...';
document.getElementById('status').className = 'status running';
const resp = await fetch('/run', {method: 'POST'});
const poll = setInterval(async () => {
const r = await fetch('/status');
const d = await r.json();
if (d.status === 'done' || d.status === 'error') {
clearInterval(poll);
location.reload();
}
document.getElementById('log').textContent = d.log.join('\\n');
}, 2000);
}
</script>
</body>
</html>""")
def render_results():
    """Build the HTML fragment that reports a finished benchmark.

    Reads the module-level ``state`` dict. Nothing is rendered unless
    ``state["status"]`` is ``"done"`` and a summary is present. Otherwise the
    fragment holds a headline speedup card, a LOCAL-vs-REMOTE comparison
    table, and a card naming the cone and the two catalogs.

    Free-form summary values (cone and catalog names) are HTML-escaped before
    being embedded, so characters such as ``<`` or ``&`` show up literally
    instead of being parsed as markup.

    Returns:
        str: HTML markup, or ``""`` when there is nothing to show.

    Raises:
        KeyError: If an expected key is missing from the stored results.
    """
    from html import escape

    if state["status"] != "done" or not state["summary"]:
        return ""

    s = state["summary"]
    local = state["local"]
    remote = state["remote"]

    speedup = s.get("compute_speedup", 0)
    if speedup > 3:
        conclusion = "Strong colocation advantage &mdash; server-side crossmatching is worth pursuing."
    elif speedup > 1.5:
        conclusion = "Moderate advantage &mdash; grows with catalog size."
    else:
        conclusion = "Minimal difference at this scale &mdash; try larger catalogs or full sky."

    # str() first so non-string values format exactly as the f-string would
    # have formatted them, then escape for safe placement in element text.
    cone = escape(str(s['cone']))
    catalog_left = escape(str(s['catalog_left']))
    catalog_right = escape(str(s['catalog_right']))

    # The markup lines are deliberately flush-left: they are part of the
    # string that is sent to the browser.
    return f"""
<h2>Results</h2>
<div class="card" style="text-align:center;">
<p>Colocation speedup (compute phase)</p>
<p class="speedup">{speedup}&times;</p>
<p>{conclusion}</p>
</div>
<div class="card">
<table>
<tr><th>Metric</th><th>LOCAL (disk)</th><th>REMOTE (HTTPS)</th></tr>
<tr><td>Download / open</td>
<td>{local['download_s']:.1f}s + {local['open_s']:.1f}s</td>
<td>&mdash; / {remote['open_s']:.1f}s</td></tr>
<tr><td class="key">Compute (I/O + CPU)</td>
<td class="key">{local['compute_s']:.2f}s</td>
<td class="key">{remote['compute_s']:.2f}s</td></tr>
<tr><td>Total wall time</td>
<td>{local['total_s']:.1f}s</td>
<td>{remote['total_s']:.1f}s</td></tr>
<tr><td>Result rows</td>
<td>{local['n_rows']}</td>
<td>{remote['n_rows']}</td></tr>
<tr><td>Partitions (L &times; R)</td>
<td>{local['n_part_left']} &times; {local['n_part_right']}</td>
<td>{remote['n_part_left']} &times; {remote['n_part_right']}</td></tr>
<tr><td>Peak memory</td>
<td>{local['peak_mb']:.0f} MB</td>
<td>{remote['peak_mb']:.0f} MB</td></tr>
</table>
</div>
<div class="card">
<p><strong>Cone:</strong> {cone}</p>
<p><strong>Catalogs:</strong> {catalog_left} &times; {catalog_right}</p>
</div>
"""
@app.get("/", response_class=HTMLResponse)
async def home():
    """Render the landing page from the current benchmark state.

    Maps ``state["status"]`` to a label and CSS class (falling back to
    ``("Unknown", "idle")`` for unrecognised values), renders the results
    fragment, and fills the ``PAGE`` template with those plus the last 100
    log lines.

    The log text is HTML-escaped before substitution. It contains captured
    output and tracebacks, and fragments such as ``<module>`` would otherwise
    be parsed as tags and disappear from the page.

    Returns:
        str: The complete HTML document.
    """
    from html import escape

    # Status labels are trusted markup (note the &mdash; entity), so they are
    # inserted as-is rather than escaped.
    status_map = {
        "idle": ("Ready to run", "idle"),
        "running": ("Running benchmark...", "running"),
        "done": ("Complete", "done"),
        "error": ("Error &mdash; check logs", "error"),
    }
    text, cls = status_map.get(state["status"], ("Unknown", "idle"))

    results_html = render_results()
    log_text = "\n".join(state["log"][-100:]) or "(no logs yet)"

    return PAGE.substitute(
        status_text=text,
        status_class=cls,
        results_html=results_html,
        log_text=escape(log_text),
    )
@app.get("/status")
async def status():
    """Return a snapshot of the benchmark state for the page's polling script.

    Returns:
        dict: ``"status"`` (current lifecycle marker), ``"log"`` (at most the
        50 most recent log lines) and ``"summary"`` (the stored summary, or
        ``None`` when no run has finished).
    """
    snapshot = {"status": state["status"]}
    snapshot["log"] = state["log"][-50:]
    snapshot["summary"] = state["summary"]
    return snapshot
@app.post("/run")
async def run():
    """Start a benchmark in the background unless one is already running.

    Resets the shared ``state`` (status, log, and previous results) and hands
    ``_run_benchmark`` to the event loop's default executor, so the request
    returns immediately.

    Returns:
        dict: ``{"status": "already running"}`` when a run is in progress,
        otherwise ``{"status": "started"}``.
    """
    if state["status"] == "running":
        return {"status": "already running"}

    state["status"] = "running"
    state["log"] = ["Starting benchmark..."]
    state["local"] = None
    state["remote"] = None
    state["summary"] = None

    # Inside a coroutine, get_running_loop() is the recommended way to reach
    # the active loop. The returned future is intentionally not awaited:
    # _run_benchmark reports its own outcome through ``state``.
    asyncio.get_running_loop().run_in_executor(None, _run_benchmark)
    return {"status": "started"}
def _log(msg):
    """Print *msg* and record it in the shared log.

    Args:
        msg: The message to print and to append to ``state["log"]``.
    """
    print(msg)
    log_lines = state["log"]
    log_lines.append(msg)
def _run_benchmark():
    """Run the benchmark and publish its outcome through ``state``.

    While the benchmark executes, ``sys.stdout`` is replaced by a stream that
    copies every non-blank write into ``state["log"]`` and passes it on to
    ``sys.__stdout__``. On success the local and remote results and the
    summary are stored, and the status becomes ``"done"``. On any
    ``Exception`` the message and traceback are appended to the log, and the
    status becomes ``"error"``.

    The terminal status is written last, after the final log lines, so a
    client that sees ``"done"`` or ``"error"`` also sees the complete log.

    Note: the redirection swaps the process-wide ``sys.stdout``, so output
    printed by other threads during the run is captured as well.
    """
    import io
    import sys
    import traceback
    from contextlib import redirect_stdout

    class LogCapture(io.TextIOBase):
        """Text stream that tees writes to the shared log and the real stdout."""

        def write(self, s):
            # Whitespace-only chunks (such as a bare newline) are not worth
            # a log entry, but they are still forwarded to the real stdout.
            if s.strip():
                state["log"].append(s.rstrip())
            sys.__stdout__.write(s)
            return len(s)

    try:
        # redirect_stdout restores the previous stream on exit, including
        # when the benchmark raises.
        with redirect_stdout(LogCapture()):
            from benchmark import run_benchmark

            local, remote, summary = run_benchmark(use_cone=True)
            state["local"] = local.to_dict()
            state["remote"] = remote.to_dict()
            state["summary"] = summary

        # Logged after the capture ends: _log both prints and appends, so
        # calling it while stdout is captured would record the line twice.
        _log(f"Done! Speedup: {summary['compute_speedup']}x")
        state["status"] = "done"
    except Exception as e:
        # Top-level boundary of the worker thread: record the failure for the
        # UI instead of letting it vanish with the discarded future.
        state["log"].append(f"ERROR: {e}")
        state["log"].append(traceback.format_exc())
        state["status"] = "error"