"""Benchmark script for NisabaRelief inference pipeline."""

import argparse
import statistics
import time
from datetime import datetime
from pathlib import Path

import numpy as np
from PIL import Image
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)
from rich.table import Table

from nisaba_relief import NisabaRelief
from util.load_val_dataset import load_val_dataset

BENCHMARK_DIR = Path(__file__).parent.parent / "data" / "benchmark"
BASELINE = BENCHMARK_DIR / "benchmark_baseline.png"
WARMUP_RUNS = 2
BENCH_RUNS = 3
# Longest side (in pixels) the test image is downscaled to before inference.
MAX_INPUT_DIM = 2048
# Percentiles of the absolute per-pixel difference reported in the diff table.
_DIFF_PERCENTILES = (50, 90, 95, 96, 97, 98, 99)


def build_timing_table(timings: list[float], n_warmup: int) -> Table:
    """Build a Rich table of per-run inference timings.

    The first ``n_warmup`` entries are shown dimmed and labelled as warmup;
    they are excluded from the "Mean ± stdev" summary row.

    Args:
        timings: Wall-clock duration of every run, in seconds, in run order.
        n_warmup: Number of leading entries in ``timings`` that are warmup runs.

    Returns:
        A table with one row per run plus a mean/stdev summary row.

    Raises:
        ValueError: If ``timings`` has no entries after the warmup runs.
    """
    bench_timings = timings[n_warmup:]
    if not bench_timings:
        raise ValueError("need at least one timing after the warmup runs")
    mean = statistics.mean(bench_timings)
    # stdev needs at least two samples; report 0 for a single benchmark run.
    stdev = statistics.stdev(bench_timings) if len(bench_timings) > 1 else 0.0

    table = Table(title="Inference Timings")
    table.add_column("Run", justify="right")
    table.add_column("Time", justify="right")
    for i, t in enumerate(timings, 1):
        is_warmup = i <= n_warmup
        label = f"[dim]{i} (warmup)[/dim]" if is_warmup else str(i - n_warmup)
        time_str = f"[dim]{t:.2f}s[/dim]" if is_warmup else f"{t:.2f}s"
        table.add_row(label, time_str)
    table.add_section()
    table.add_row("[bold]Mean[/bold]", f"[bold]{mean:.2f} ± {stdev:.2f}s[/bold]")
    return table


def build_diff_table(flat: np.ndarray, max_diff: int) -> Table:
    """Build a Rich table summarising the pixel difference against the baseline.

    The comparison PASSes when the 98th percentile of the absolute
    difference is at most 1, and FAILs otherwise.

    Args:
        flat: Flattened array of absolute per-pixel differences.
        max_diff: Maximum value in ``flat``. It is passed in because the
            caller already computed it.

    Returns:
        A table with the mean, selected percentiles and max of ``flat``.
    """
    percentile_vals = np.percentile(flat, _DIFF_PERCENTILES)
    by_percentile = dict(zip(_DIFF_PERCENTILES, percentile_vals))
    status = "PASS" if by_percentile[98] <= 1 else "FAIL"
    status_style = "green" if status == "PASS" else "red"

    table = Table(
        title=f"Pixel Diff vs Baseline — [{status_style}]{status}[/{status_style}]"
    )
    table.add_column("Stat", style="bold")
    table.add_column("Value", justify="right")
    table.add_row("Mean", f"{flat.mean():.4f}")
    for pct, val in by_percentile.items():
        table.add_row(f"p{pct}", f"{val:.0f}")
    table.add_row("Max", str(max_diff))
    return table


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the benchmark script."""
    parser = argparse.ArgumentParser(
        description="Benchmark NisabaRelief inference pipeline"
    )
    parser.add_argument(
        "--weights-dir",
        default=".",
        metavar="PATH",
        help="path to weights directory (default: .)",
    )
    parser.add_argument(
        "--device",
        default=None,
        metavar="DEVICE",
        help="device to run inference on, e.g. cuda, cpu (default: cuda if available, else cpu)",
    )
    return parser.parse_args()


def _load_test_image(max_dim: int = MAX_INPUT_DIM) -> Image.Image:
    """Return the first validation photo, downscaled so its longest side is at most ``max_dim``."""
    test_image = load_val_dataset()[0]["photo"]
    longest = max(test_image.size)
    if longest > max_dim:
        scale = max_dim / longest
        new_size = (round(test_image.width * scale), round(test_image.height * scale))
        test_image = test_image.resize(new_size, Image.LANCZOS)
    return test_image


def _run_benchmark(model, image) -> tuple[list[float], Image.Image]:
    """Time ``WARMUP_RUNS + BENCH_RUNS`` calls of ``model.process`` on ``image``.

    Returns:
        The per-run durations in seconds, and the output of the first
        non-warmup run.

    Raises:
        ValueError: If ``BENCH_RUNS`` is less than 1. This is checked before
            any inference runs, so no time is wasted.
    """
    if BENCH_RUNS < 1:
        raise ValueError("BENCH_RUNS must be at least 1")

    timings = []
    output = None
    total_runs = WARMUP_RUNS + BENCH_RUNS
    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
    )
    with progress:
        task = progress.add_task("Benchmarking", total=total_runs)
        for i in range(total_runs):
            t0 = time.perf_counter()
            result = model.process(image, show_pbar=False)
            timings.append(time.perf_counter() - t0)
            progress.advance(task)
            # Index WARMUP_RUNS is the first timed (non-warmup) run.
            if i == WARMUP_RUNS:
                output = result
    return timings, output


def _compare_with_baseline(output: Image.Image, timestamp: str, console: Console) -> None:
    """Print a pixel-diff table of ``output`` vs ``BASELINE`` and save an amplified diff image."""
    output_arr = np.array(output)
    # Use a context manager so the file handle opened by Image.open is closed.
    with Image.open(BASELINE) as baseline_img:
        baseline_arr = np.array(baseline_img)

    if baseline_arr.shape != output_arr.shape:
        console.print(
            f"[red]Baseline shape {baseline_arr.shape} does not match output "
            f"shape {output_arr.shape}; skipping pixel diff.[/red] "
            f"Delete {BASELINE} to regenerate it."
        )
        return

    # Cast to int before subtracting so unsigned pixel values cannot wrap around.
    diff = np.abs(output_arr.astype(int) - baseline_arr.astype(int))
    flat = diff.flatten()
    max_diff = int(flat.max())
    console.print(build_diff_table(flat, max_diff))

    if max_diff > 0:
        gain = 255 // max_diff
        diff_img = Image.fromarray(np.clip(diff * gain, 0, 255).astype("uint8"))
        # Save alongside the run image and baseline, not in the working directory.
        diff_path = BENCHMARK_DIR / f"benchmark_{timestamp}_diff.png"
        diff_img.save(diff_path)
        console.print(
            f"Diff image saved to [cyan]{diff_path}[/cyan] (amplified {gain}x)"
        )


def main():
    """Time NisabaRelief inference, save the output and compare it with the stored baseline.

    The first time the script runs, no baseline exists yet, so the output of
    this run is saved as the baseline.
    """
    args = _parse_args()
    console = Console()

    test_image = _load_test_image()
    console.print(f"Input size: [cyan]{test_image.width}x{test_image.height}[/cyan]")

    model_kwargs = dict(seed=42, weights_dir=Path(args.weights_dir))
    if args.device is not None:
        model_kwargs["device"] = args.device
    model = NisabaRelief(**model_kwargs)

    timings, output = _run_benchmark(model, test_image)
    console.print(build_timing_table(timings, WARMUP_RUNS))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_path = BENCHMARK_DIR / f"benchmark_{timestamp}.png"
    run_path.parent.mkdir(parents=True, exist_ok=True)
    output.save(run_path)
    console.print(f"Run image saved to [cyan]{run_path}[/cyan]")

    if not BASELINE.exists():
        output.save(BASELINE)
        console.print(f"Baseline saved to [cyan]{BASELINE}[/cyan]")
    else:
        _compare_with_baseline(output, timestamp, console)


if __name__ == "__main__":
    main()