"""
Generate qualitative comparison figures:
  1. Full-frame side-by-side: sharp | blur | evssm | vdiff | turtle
  2. Zoom-in crops showing detail
  3. TURTLE artifact analysis (amplified absolute-difference maps)
"""
import argparse
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageFont

BASE = Path(__file__).parent.parent

# Input roots shared by the comparison rows and the TURTLE analysis.
SHARP_DIR = BASE / "data/scannet_blur_proto/vddiff/test/sharp"
BLUR_DIR = BASE / "data/scannet_blur_proto/vddiff/test/blur"
TURTLE_DIR = BASE / "data/turtle_deblurred"

# (key, column label, root); frames are read from <root>/<scene>/<NNNNNN>.png.
METHODS = [
    ("sharp", "Sharp GT", SHARP_DIR),
    ("blur", "Blur input", BLUR_DIR),
    ("evssm", "EVSSM", BASE / "data/evssm_deblurred"),
    ("vdiff", "VD-Diff", BASE / "data/vdiff_deblurred"),
    ("turtle", "TURTLE", TURTLE_DIR),
]


def label_bar(text, w, h=32, bg=(30, 30, 30), fg=(255, 255, 255)):
    """Render *text* centred on a solid RGB bar.

    Returns a uint8 array of shape (h, w, 3).
    """
    img = Image.new("RGB", (w, h), bg)
    draw = ImageDraw.Draw(img)
    left, top, right, bottom = draw.textbbox((0, 0), text)
    tw, th = right - left, bottom - top
    # textbbox is measured from the draw origin and usually has a non-zero
    # left/top offset; subtract it so the glyphs (not the origin) are centred.
    draw.text(((w - tw) // 2 - left, (h - th) // 2 - top), text, fill=fg)
    return np.array(img)


def psnr(a, b):
    """Return the PSNR in dB between two same-shaped 8-bit-range images.

    Identical (or near-identical, MSE <= 1e-10) inputs return 100.0 instead
    of infinity.
    """
    mse = np.mean((a.astype(float) - b.astype(float)) ** 2)
    if mse <= 1e-10:
        return 100.0
    return 20 * np.log10(255 / np.sqrt(mse))


def load(path):
    """Load an image file as an RGB uint8 array of shape (H, W, 3)."""
    # Context manager closes the file handle Image.open keeps for lazy loading.
    with Image.open(path) as im:
        return np.array(im.convert("RGB"))


def add_label(img_arr, text, psnr_val=None):
    """Stack a 36-px label bar on top of *img_arr*.

    When *psnr_val* is given (including 0.0) it is appended to the label.
    """
    w = img_arr.shape[1]
    label = text if psnr_val is None else f"{text} PSNR={psnr_val:.2f}"
    bar = label_bar(label, w, h=36)
    return np.vstack([bar, img_arr])


def _match_size(img, ref_shape):
    """Resize *img* to the (H, W) of *ref_shape* if it differs (Pillow default filter)."""
    h, w = ref_shape[:2]
    if img.shape[:2] == (h, w):
        return img
    return np.array(Image.fromarray(img).resize((w, h)))


def _lanczos_resize(img, w, h):
    """Resize a uint8 image array to width *w* and height *h* with Lanczos."""
    return np.array(Image.fromarray(img).resize((w, h), Image.LANCZOS))


def _pad(arr, h, w):
    """Zero-pad an (H, W, 3) array on the bottom/right up to at least (h, w)."""
    dh = max(h - arr.shape[0], 0)
    dw = max(w - arr.shape[1], 0)
    if dh == 0 and dw == 0:
        return arr
    return np.pad(arr, ((0, dh), (0, dw), (0, 0)))


def make_comparison_row(scene, frame_idx, crop_box=None, scale=0.4):
    """Returns a row of images: [sharp | blur | evssm | vdiff | turtle]

    Args:
        scene: scene folder name, e.g. "scene0000_00".
        frame_idx: frame index; files are named f"{frame_idx:06d}.png".
        crop_box: optional (x1, y1, x2, y2) pixel box, applied to every image
            after it has been resized to the sharp GT resolution.
        scale: display downscale factor; only applied when crop_box is None
            and scale < 1.0.

    Returns:
        A uint8 array (H, W_total, 3), or None when the sharp GT frame is
        missing or no method has this frame. Methods whose frame file is
        missing are skipped (their column is omitted).

    PSNR against the sharp GT is computed at full resolution (on the crop when
    cropping) *before* display downscaling, so the printed numbers are the
    actual metric rather than a metric on resampled thumbnails.
    """
    frame_name = f"{frame_idx:06d}.png"
    sharp_path = SHARP_DIR / scene / frame_name
    if not sharp_path.exists():
        return None
    sharp = load(sharp_path)

    if crop_box is not None:
        x1, y1, x2, y2 = crop_box
        sharp_ref = sharp[y1:y2, x1:x2]
    else:
        sharp_ref = sharp

    cols = []
    for key, label, root in METHODS:
        path = root / scene / frame_name
        if not path.exists():
            continue
        img = _match_size(load(path), sharp.shape)
        if crop_box is not None:
            img = img[y1:y2, x1:x2]
        p = psnr(img, sharp_ref) if key != "sharp" else None
        if scale < 1.0 and crop_box is None:
            img = _lanczos_resize(img, int(img.shape[1] * scale), int(img.shape[0] * scale))
        cols.append(add_label(img, label, p))

    if not cols:
        return None
    # Columns should already share a height; pad defensively so hstack never fails.
    max_h = max(c.shape[0] for c in cols)
    return np.hstack([_pad(c, max_h, c.shape[1]) for c in cols])


def make_scene_grid(scene, frame_idxs, out_path, crop_box=None, scale=0.4):
    """Stack comparison rows for *frame_idxs* under a scene title and save to *out_path*.

    crop_box, if given, is (x1, y1, x2, y2). Frames without a sharp GT are
    skipped; if none remain, a message is printed and nothing is written.
    """
    rows = []
    for fi in frame_idxs:
        row = make_comparison_row(scene, fi, crop_box=crop_box, scale=scale)
        if row is not None:
            rows.append(row)
    if not rows:
        print(f" No rows for {scene}")
        return

    # A method missing for only some frames yields narrower rows; pad to a
    # common width so vstack does not raise.
    width = max(r.shape[1] for r in rows)
    rows = [_pad(r, r.shape[0], width) for r in rows]

    title = f"Scene: {scene}" + (f" crop={crop_box}" if crop_box else "")
    scene_bar = label_bar(title, width, h=28, bg=(60, 30, 30))
    grid = np.vstack([scene_bar] + rows)
    Image.fromarray(grid).save(out_path, quality=93)
    print(f" saved -> {out_path}")


def _make_turtle_analysis(scene, frame_idx, out_dir, scale):
    """Save blur / TURTLE / sharp plus x3-amplified absolute-difference maps.

    Writes <out_dir>/<scene>_turtle_analysis.jpg. Does nothing if any of the
    blur, TURTLE or sharp frames is missing. Blur and TURTLE are resized to the
    sharp GT resolution first so the difference maps are well defined.
    """
    frame_name = f"{frame_idx:06d}.png"
    blur_path = BLUR_DIR / scene / frame_name
    turtle_path = TURTLE_DIR / scene / frame_name
    sharp_path = SHARP_DIR / scene / frame_name
    if not (turtle_path.exists() and blur_path.exists() and sharp_path.exists()):
        return

    sharp_u8 = load(sharp_path)
    blur = _match_size(load(blur_path), sharp_u8.shape).astype(float)
    turtle = _match_size(load(turtle_path), sharp_u8.shape).astype(float)
    sharp = sharp_u8.astype(float)

    def amplified_diff(a, b):
        # |a - b| scaled x3 (labels say "×3") and clipped back to uint8.
        return np.clip(np.abs(a - b) * 3, 0, 255).astype(np.uint8)

    diff_turtle_vs_blur = amplified_diff(turtle, blur)
    diff_blur_vs_sharp = amplified_diff(blur, sharp)
    diff_turtle_vs_sharp = amplified_diff(turtle, sharp)

    nh = int(sharp.shape[0] * scale)
    nw = int(sharp.shape[1] * scale)

    def rs(a):
        return _lanczos_resize(a, nw, nh)

    cols = [
        add_label(rs(blur.astype(np.uint8)), "Blur input"),
        add_label(rs(turtle.astype(np.uint8)), f"TURTLE PSNR={psnr(turtle, sharp):.2f}"),
        add_label(rs(sharp_u8), "Sharp GT"),
        add_label(rs(diff_turtle_vs_blur), "Diff: TURTLE vs Blur (×3)"),
        add_label(rs(diff_blur_vs_sharp), "Diff: Blur vs Sharp (×3)"),
        add_label(rs(diff_turtle_vs_sharp), "Diff: TURTLE vs Sharp (×3)"),
    ]
    grid = np.hstack(cols)
    path = out_dir / f"{scene}_turtle_analysis.jpg"
    Image.fromarray(grid).save(path, quality=93)
    print(f" turtle analysis -> {path}")


def main():
    """Parse CLI arguments and write overview, crop and TURTLE-analysis figures per scene."""
    p = argparse.ArgumentParser()
    p.add_argument("--out-dir", default=str(BASE / "outputs/eval/qualitative"))
    p.add_argument("--scenes", nargs="*", default=["scene0000_00", "scene0000_01", "scene0000_02"])
    p.add_argument("--frames", nargs="*", type=int, default=[0, 15, 30, 45])
    p.add_argument("--scale", type=float, default=0.4)
    args = p.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Crop regions for detail analysis in (y1, x1, y2, x2) order (rows first).
    # The bottom-right box ends at (968, 1296), presumably the full frame
    # size (H x W) — confirm if the input resolution changes.
    crops = {
        "topleft": (0, 0, 320, 480),
        "center": (300, 400, 620, 880),
        "bottomright": (600, 700, 968, 1296),
    }

    for scene in args.scenes:
        print(f"\n[{scene}]")

        # 1. Full-frame overview (downscaled)
        make_scene_grid(scene, args.frames, out_dir / f"{scene}_overview.jpg", scale=args.scale)

        # 2. Zoomed crops. make_comparison_row expects (x1, y1, x2, y2), so
        # reorder the row-first boxes above before passing them on.
        for crop_name, (y1, x1, y2, x2) in crops.items():
            make_scene_grid(
                scene,
                args.frames[:2],
                out_dir / f"{scene}_crop_{crop_name}.jpg",
                crop_box=(x1, y1, x2, y2),
                scale=1.0,
            )

        # 3. TURTLE diff map (show where TURTLE changes the image)
        if args.frames:
            _make_turtle_analysis(scene, args.frames[0], out_dir, args.scale)

    print(f"\nAll figures saved to {out_dir}/")


if __name__ == "__main__":
    main()