"""
Evaluate NisabaRelief on the validation set, optionally sweeping over step counts.

Usage:
    python evaluation.py            # full dataset, num_steps=2
    python evaluation.py --sweep    # subset, steps=[1,2,4,8]
"""

import argparse
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
from PIL import Image
from rich.console import Console, Group
from rich.live import Live
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)
from rich.table import Table

from nisaba_relief import NisabaRelief
from util.metrics import compute_metrics, METRIC_NAMES, LABELS
from util.load_val_dataset import load_val_dataset

# Step counts evaluated when --sweep is given.
SWEEP_STEPS = [1, 2, 4, 8]
# Step count used for a normal (non-sweep) run.
DEFAULT_STEPS = 2
# In --sweep mode only every SWEEP_STRIDE-th validation row is kept ...
SWEEP_STRIDE = 4
# ... and at most SWEEP_MAX rows in total.
SWEEP_MAX = 175
# Directory where predicted PNGs are written and re-read on later runs.
EVALS_DIR = Path(__file__).parent.parent / "data" / "evals"


def _eta(n_done: int, n_total: int, elapsed: float) -> str:
    """Estimate the remaining wall-clock time as a human-readable string.

    Args:
        n_done: number of items finished so far.
        n_total: total number of items to process.
        elapsed: seconds spent so far.

    Returns:
        "Done" once every item is finished (and there was work to do),
        "?" before any item has finished, otherwise the extrapolated
        remaining time formatted by ``timedelta`` (whole seconds).
    """
    if n_total > 0 and n_done >= n_total:
        return "Done"
    if n_done <= 0:
        # No completed items yet, so there is no rate to extrapolate from.
        return "?"
    remaining = n_total - n_done
    seconds_left = int(elapsed / n_done * remaining)
    return str(timedelta(seconds=seconds_left))
# Metrics reported in decibels; every other metric is printed as a unitless value.
_DB_METRICS = ("psnr", "psnr_hvsm", "sre")


def build_table(
    results: dict,
    n_done: int = 0,
    n_total: int = 0,
    elapsed: float = 0.0,
) -> Table:
    """Render the per-step metric summary as a rich ``Table``.

    Args:
        results: mapping ``num_steps -> {metric_name -> list of per-image values}``.
            One column is emitted per key, in the dict's iteration order.
        n_done: number of dataset rows processed so far (used for the ETA title).
        n_total: total number of rows to process.
        elapsed: seconds elapsed since evaluation started.

    Returns:
        A table with one row per entry of ``METRIC_NAMES`` showing
        ``mean ± std`` for each step count, or "—" when no values exist yet.
    """
    eta = _eta(n_done, n_total, elapsed)
    steps = list(results.keys())

    table = Table(title=f"Results — ETA: {eta}")
    table.add_column("Metric", style="bold")
    for s in steps:
        table.add_column(f"Steps={s}", justify="right")

    for name in METRIC_NAMES:
        cells = []
        for s in steps:
            arr = np.array(results[s][name])
            if len(arr) == 0:
                cells.append("—")
            elif name in _DB_METRICS:
                cells.append(f"{arr.mean():.2f} ± {arr.std():.2f} dB")
            else:
                cells.append(f"{arr.mean():.4f} ± {arr.std():.4f}")
        table.add_row(LABELS[name], *cells)
    return table


def load_grayscale(img: Image.Image) -> np.ndarray:
    """Return *img* converted to 8-bit grayscale ("L" mode) as a NumPy array."""
    return np.array(img.convert("L"))


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Evaluate NisabaRelief model")
    parser.add_argument(
        "--weights-dir",
        default=".",
        metavar="PATH",
        help="path to weights directory (default: .)",
    )
    parser.add_argument(
        "--sweep",
        action="store_true",
        help="sweep over steps=[1,2,4,8] on a dataset subset",
    )
    # Cached predictions are keyed only by row and step count, not by the
    # weights that produced them, so evaluating new weights needs a way to
    # bypass the cache instead of silently reporting stale results.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="regenerate predictions even if a cached PNG already exists "
        "(use after changing the weights)",
    )
    return parser.parse_args()


def _predict_grayscale(
    model: NisabaRelief, photo, save_path: Path, overwrite: bool = False
) -> np.ndarray:
    """Return the model's prediction for *photo* as grayscale, caching it on disk.

    If *save_path* exists and *overwrite* is false, the cached PNG is loaded
    instead of running the model. Otherwise the model is run (with whatever
    ``model.num_steps`` the caller has set) and the result is saved to
    *save_path*.
    """
    if save_path.exists() and not overwrite:
        with Image.open(save_path) as cached:
            return load_grayscale(cached)

    pred_img = model.process(photo, show_pbar=False)
    try:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file and atomically rename it, so an interrupted
        # save never leaves a truncated PNG that later runs would treat as a
        # valid cache hit.
        tmp_path = save_path.with_name(save_path.name + ".tmp")
        pred_img.save(tmp_path, format="PNG")
        tmp_path.replace(save_path)
        return load_grayscale(pred_img)
    finally:
        pred_img.close()


def _resize_to(pred: np.ndarray, shape: tuple) -> np.ndarray:
    """Resize grayscale *pred* to *shape* (height, width) with Lanczos if it differs."""
    if pred.shape == shape:
        return pred
    return np.array(
        Image.fromarray(pred).resize((shape[1], shape[0]), Image.LANCZOS)
    )


def main():
    """Run the evaluation CLI.

    For every selected validation row and every step count:
    1. predict the image (or reuse the cached PNG);
    2. score it against the row's ``msii`` ground truth with ``compute_metrics``;
    3. show a live table of the results, then print the final table.
    """
    args = _parse_args()

    rows = load_val_dataset()
    if args.sweep:
        # Evenly strided subset: every SWEEP_STRIDE-th row, at most SWEEP_MAX rows.
        rows = rows.select(
            range(0, min(len(rows), SWEEP_MAX * SWEEP_STRIDE), SWEEP_STRIDE)
        )
        steps_to_run = SWEEP_STEPS
    else:
        steps_to_run = [DEFAULT_STEPS]

    results = {s: {m: [] for m in METRIC_NAMES} for s in steps_to_run}
    model = NisabaRelief(seed=42, batch_size=4, weights_dir=Path(args.weights_dir))

    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("[cyan]{task.fields[hs_number]}"),
    )
    task_desc = "Step Sweep" if args.sweep else "Evaluating"
    task = progress.add_task(task_desc, total=len(rows), hs_number="")

    start_time = time.monotonic()
    with Live(
        Group(progress, build_table(results)),
        refresh_per_second=4,
        transient=True,
    ) as live:
        for n_done, row in enumerate(rows):
            progress.update(task, hs_number=row["hs_number"])
            gt = load_grayscale(row["msii"])

            for num_steps in steps_to_run:
                model.num_steps = num_steps
                save_name = (
                    f"{row['hs_number']}_photo_fullview_"
                    f"{int(row['variation']):02d}-step{num_steps}.png"
                )
                pred = _predict_grayscale(
                    model, row["photo"], EVALS_DIR / save_name, args.overwrite
                )
                # Metrics need pixel-aligned inputs; match the ground-truth size.
                pred = _resize_to(pred, gt.shape)

                for name, val in compute_metrics(pred, gt).items():
                    results[num_steps][name].append(val)

            elapsed = time.monotonic() - start_time
            live.update(
                Group(progress, build_table(results, n_done + 1, len(rows), elapsed))
            )
            progress.advance(task)

    # The Live display is transient, so print the final table once it closes.
    final_elapsed = time.monotonic() - start_time
    Console().print(build_table(results, len(rows), len(rows), final_elapsed))


if __name__ == "__main__":
    main()