| """Benchmark script for NisabaRelief inference pipeline.""" |
|
|
| import argparse |
| import statistics |
| import time |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image |
| from rich.console import Console |
| from rich.progress import ( |
| BarColumn, |
| MofNCompleteColumn, |
| Progress, |
| TextColumn, |
| TimeElapsedColumn, |
| ) |
| from rich.table import Table |
|
|
| from nisaba_relief import NisabaRelief |
| from util.load_val_dataset import load_val_dataset |
|
|
# Artifacts directory: <two levels above this script>/data/benchmark.
BENCHMARK_DIR = Path(__file__).parent.parent.joinpath("data", "benchmark")
# Reference output every later run is diffed against; written on the first run
# when it does not exist yet.
BASELINE = BENCHMARK_DIR.joinpath("benchmark_baseline.png")
# Leading runs that are timed but excluded from the mean/stdev summary.
WARMUP_RUNS = 2
# Runs whose timings feed the mean/stdev summary.
BENCH_RUNS = 3
|
|
|
|
def build_timing_table(timings: list[float], n_warmup: int) -> Table:
    """Build a rich table listing every inference run plus a summary row.

    Args:
        timings: Wall-clock durations in seconds, one per run, in execution order.
        n_warmup: Number of leading entries in ``timings`` that are warmup runs.
            They are shown dimmed and excluded from the summary statistics.

    Returns:
        A ``Table`` with one row per run followed by a "Mean" row formatted as
        ``mean ± stdev``. The sample stdev needs at least two benchmark runs and
        is reported as 0.00 otherwise. When there are no benchmark runs at all,
        the summary reads "n/a" instead of raising.
    """
    bench_timings = timings[n_warmup:]

    table = Table(title="Inference Timings")
    table.add_column("Run", justify="right")
    table.add_column("Time", justify="right")
    for i, t in enumerate(timings, 1):
        if i <= n_warmup:
            table.add_row(f"[dim]{i} (warmup)[/dim]", f"[dim]{t:.2f}s[/dim]")
        else:
            # Benchmark runs are renumbered from 1 after the warmups.
            table.add_row(str(i - n_warmup), f"{t:.2f}s")
    table.add_section()

    if bench_timings:
        mean = statistics.mean(bench_timings)
        # statistics.stdev requires at least two data points.
        stdev = statistics.stdev(bench_timings) if len(bench_timings) > 1 else 0.0
        summary = f"[bold]{mean:.2f} ± {stdev:.2f}s[/bold]"
    else:
        # Every run was a warmup; statistics.mean([]) would raise StatisticsError.
        summary = "[bold]n/a[/bold]"
    table.add_row("[bold]Mean[/bold]", summary)
    return table
|
|
|
|
def build_diff_table(
    flat: np.ndarray,
    max_diff: int,
    *,
    pass_percentile: float = 98,
    tolerance: float = 1,
) -> Table:
    """Summarize per-element absolute differences against the baseline as a table.

    Args:
        flat: Absolute per-element differences between the run output and the
            baseline (the caller passes a flattened array).
        max_diff: Largest value in ``flat``; shown as the "Max" row.
        pass_percentile: Percentile whose value gates PASS/FAIL (default 98).
        tolerance: Largest allowed value at ``pass_percentile`` for PASS
            (default 1).

    Returns:
        A ``Table`` titled with the PASS/FAIL status. It contains rows for the
        mean, each reported percentile, and the maximum.

    Raises:
        ValueError: If ``flat`` is empty.
    """
    if flat.size == 0:
        raise ValueError("cannot summarize an empty pixel diff")

    reported = (50, 90, 95, 96, 97, 98, 99)
    # Compute all percentiles in one call. Key them by percentile so labels and
    # the pass/fail gate never depend on list positions.
    qs = sorted(set(reported) | {pass_percentile})
    values = dict(zip(qs, np.percentile(flat, qs)))

    passed = values[pass_percentile] <= tolerance
    status = "PASS" if passed else "FAIL"
    status_style = "green" if passed else "red"

    table = Table(
        title=f"Pixel Diff vs Baseline — [{status_style}]{status}[/{status_style}]"
    )
    table.add_column("Stat", style="bold")
    table.add_column("Value", justify="right")
    table.add_row("Mean", f"{flat.mean():.4f}")
    for q in qs:
        table.add_row(f"p{q}", f"{values[q]:.0f}")
    table.add_row("Max", str(max_diff))
    return table
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options (--weights-dir, --device)."""
    parser = argparse.ArgumentParser(
        description="Benchmark NisabaRelief inference pipeline"
    )
    parser.add_argument(
        "--weights-dir",
        default=".",
        metavar="PATH",
        help="path to weights directory (default: .)",
    )
    parser.add_argument(
        "--device",
        default=None,
        metavar="DEVICE",
        help="device to run inference on, e.g. cuda, cpu (default: cuda if available, else cpu)",
    )
    return parser.parse_args()


def _load_test_image(max_dim: int = 2048) -> "Image.Image":
    """Return the first validation photo, downscaled so its longer side is <= max_dim."""
    test_image = load_val_dataset()[0]["photo"]
    longest = max(test_image.size)
    if longest > max_dim:
        scale = max_dim / longest
        new_size = (round(test_image.width * scale), round(test_image.height * scale))
        test_image = test_image.resize(new_size, Image.LANCZOS)
    return test_image


def _run_benchmark(model, image) -> tuple:
    """Run WARMUP_RUNS + BENCH_RUNS inferences on ``image`` and time each one.

    Returns:
        ``(timings, output)``. ``timings`` holds per-run seconds in execution
        order. ``output`` is what ``model.process`` returned on the first
        non-warmup run; it stays None if BENCH_RUNS is 0. Using a fixed run
        index means every invocation compares the same run against the baseline.
    """
    timings = []
    output = None
    total_runs = WARMUP_RUNS + BENCH_RUNS
    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
    )
    with progress:
        task = progress.add_task("Benchmarking", total=total_runs)
        for i in range(total_runs):
            t0 = time.perf_counter()
            result = model.process(image, show_pbar=False)
            timings.append(time.perf_counter() - t0)
            progress.advance(task)
            if i == WARMUP_RUNS:
                output = result
    return timings, output


def _compare_to_baseline(output, timestamp: str, console: "Console") -> None:
    """Diff ``output`` against BASELINE, or create BASELINE if it is missing.

    When the two differ, an amplified diff image is saved to BENCHMARK_DIR next
    to the run image.
    """
    if not BASELINE.exists():
        output.save(BASELINE)
        console.print(f"Baseline saved to [cyan]{BASELINE}[/cyan]")
        return

    output_arr = np.array(output)
    # Use a context manager so the baseline file handle is released promptly.
    with Image.open(BASELINE) as baseline_img:
        baseline_arr = np.array(baseline_img)

    # Differently-sized or differently-moded images would otherwise fail (or
    # silently broadcast) in the subtraction below.
    if output_arr.shape != baseline_arr.shape:
        console.print(
            f"[red]Output shape {output_arr.shape} does not match baseline shape "
            f"{baseline_arr.shape}; delete {BASELINE} to regenerate it.[/red]"
        )
        return

    # Widen to int so the subtraction cannot wrap around in unsigned dtypes.
    diff = np.abs(output_arr.astype(int) - baseline_arr.astype(int))
    flat = diff.ravel()
    max_diff = int(flat.max())
    console.print(build_diff_table(flat, max_diff))

    if max_diff > 0:
        # Stretch the largest difference toward 255 so small errors are visible.
        # The factor is clamped to >= 1 in case max_diff exceeds 255.
        gain = max(1, 255 // max_diff)
        diff_img = Image.fromarray(np.clip(diff * gain, 0, 255).astype("uint8"))
        diff_path = BENCHMARK_DIR / f"benchmark_{timestamp}_diff.png"
        diff_img.save(diff_path)
        console.print(
            f"Diff image saved to [cyan]{diff_path}[/cyan] (amplified {gain}x)"
        )


def main():
    """Benchmark NisabaRelief on one validation image and check output stability.

    Steps:
      1. Time WARMUP_RUNS + BENCH_RUNS inferences and print a timing table.
      2. Save the first benchmark output as a timestamped PNG in BENCHMARK_DIR.
      3. Compare that output to BASELINE (creating BASELINE on the first run).
    """
    args = _parse_args()
    console = Console()

    test_image = _load_test_image()
    console.print(f"Input size: [cyan]{test_image.width}x{test_image.height}[/cyan]")

    model_kwargs = {"seed": 42, "weights_dir": Path(args.weights_dir)}
    # Only forward --device when given, so NisabaRelief keeps its own default.
    if args.device is not None:
        model_kwargs["device"] = args.device
    model = NisabaRelief(**model_kwargs)

    timings, output = _run_benchmark(model, test_image)
    console.print(build_timing_table(timings, WARMUP_RUNS))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_path = BENCHMARK_DIR / f"benchmark_{timestamp}.png"
    run_path.parent.mkdir(parents=True, exist_ok=True)
    output.save(run_path)
    console.print(f"Run image saved to [cyan]{run_path}[/cyan]")

    _compare_to_baseline(output, timestamp, console)
|
|
|
|
# Script entry point: run the benchmark only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|