Initial upload: BPN deblur pipeline code (scripts, triangle-splatting, BAGS, EVSSM forks)
c75b162 verified | """ | |
Step 1: Compute nima-koniq scores on all blur input frames → select --n-frames (default 10) stratified frames per scene
Step 2: Score every method in METHODS (blur input, EVSSM, VD-Diff, TURTLE) on the same selected frames
Step 3: Output a summary comparison table, a per-scene visual grid, and iqa_results.json
| """ | |
| import argparse, json | |
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image, ImageDraw | |
| import torch | |
| import pyiqa | |
# Root used for every data/output path below: two levels up from this file.
BASE = Path(__file__).parent.parent

# Sources compared by the script, as (key, display label, root directory).
# main() looks up <root>/<scene>/<frame>.png for each entry.
_DATA_DIR = BASE / "data"
METHODS = [
    ("blur", "Blur Input", _DATA_DIR / "scannet_blur_proto" / "vddiff" / "test" / "blur"),
    ("evssm", "EVSSM", _DATA_DIR / "evssm_deblurred"),
    ("vdiff", "VD-Diff", _DATA_DIR / "vdiff_deblurred"),
    ("turtle", "TURTLE", _DATA_DIR / "turtle_deblurred"),
]
def label_bar(text, w, h=32, bg=(30, 30, 30), fg=(255, 255, 255)):
    """Render ``text`` centred on a solid ``w`` x ``h`` bar.

    Args:
        text: Caption, drawn with Pillow's default font.
        w: Bar width in pixels.
        h: Bar height in pixels.
        bg: Background RGB colour.
        fg: Text RGB colour.

    Returns:
        ``np.ndarray`` of shape ``(h, w, 3)`` and dtype ``uint8``.
    """
    img = Image.new("RGB", (w, h), bg)
    draw = ImageDraw.Draw(img)
    left, top, right, bottom = draw.textbbox((0, 0), text)
    text_w, text_h = right - left, bottom - top
    # textbbox() anchored at (0, 0) can start at a non-zero (left, top) offset,
    # so subtract it to centre the drawn glyphs rather than the anchor point.
    # Clamp x so an over-long caption is left-aligned (clipped on the right)
    # instead of being clipped on both sides.
    x = max(0, (w - text_w) // 2) - left
    y = (h - text_h) // 2 - top
    draw.text((x, y), text, fill=fg)
    return np.array(img)
def score_image(metric, img_path, device):
    """Return the scalar score ``metric`` assigns to the image at ``img_path``.

    The image is converted to RGB, turned into a 1x3xHxW float tensor with
    torchvision's ``ToTensor`` (values scaled to [0, 1]), moved to ``device``
    and scored without gradient tracking.

    Args:
        metric: Callable IQA metric (here a pyiqa metric) taking a batched tensor.
        img_path: Path to the image file.
        device: torch device the tensor is moved to.

    Returns:
        The metric output converted to a Python ``float``.
    """
    from torchvision import transforms

    # Close the source file deterministically; Pillow does not always release
    # the handle after load() (e.g. for PNG).
    with Image.open(img_path) as src:
        rgb = src.convert("RGB")
    tensor = transforms.ToTensor()(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        return float(metric(tensor).item())
def _select_device(gpu):
    """Pick the torch device for scoring.

    Uses ``cuda:{gpu}`` when that index is visible to this process, otherwise
    ``cuda:0`` (e.g. when CUDA_VISIBLE_DEVICES already restricts the process to
    the requested card), and CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if 0 <= gpu < torch.cuda.device_count():
        return torch.device(f"cuda:{gpu}")
    return torch.device("cuda:0")


def _stratified_pick(scored, n_frames):
    """Pick up to ``n_frames`` entries evenly spaced through ``scored``.

    ``scored`` is expected to be sorted by score already. At most
    ``len(scored)`` entries are returned and indices are de-duplicated, so a
    scene with fewer frames than requested never yields repeated frames.
    """
    if not scored or n_frames <= 0:
        return []
    n = min(n_frames, len(scored))
    idxs = np.linspace(0, len(scored) - 1, n, dtype=int)
    return [scored[i] for i in sorted(set(idxs.tolist()))]


def _score_methods(metric, scene, selected_names, device):
    """Score every method's copy of ``selected_names`` for one scene.

    Methods without a ``<root>/<scene>`` directory are omitted. Frames missing
    from an existing directory are recorded as ``None`` and excluded from the mean.

    Returns:
        ``{key: {"per_frame": [float | None, ...], "mean": float | None, "label": str}}``
    """
    results = {}
    for key, label, root in METHODS:
        method_dir = root / scene
        if not method_dir.exists():
            continue
        per_frame = []
        for fname in selected_names:
            fpath = method_dir / fname
            per_frame.append(score_image(metric, fpath, device) if fpath.exists() else None)
        valid = [s for s in per_frame if s is not None]
        mean_score = float(np.mean(valid)) if valid else None
        results[key] = {"per_frame": per_frame, "mean": mean_score, "label": label}
        # Compare against None so a legitimate 0.0 mean is still reported.
        if mean_score is not None:
            print(f" {label:12s}: mean={mean_score:.4f}")
        else:
            print(f" {label}: N/A")
    return results


def _pad_to(img, height, width):
    """Zero-pad an (H, W, 3) uint8 image on the bottom/right to (height, width)."""
    out = np.zeros((height, width, 3), np.uint8)
    out[: img.shape[0], : img.shape[1]] = img
    return out


def _build_grid(scene, selected, scores_by_method):
    """Assemble the frames x methods comparison grid for one scene.

    Args:
        scene: Scene sub-directory name.
        selected: ``(frame_name, blur_score)`` pairs; one grid row each.
        scores_by_method: The ``_score_methods`` result for this scene.

    Returns:
        An (H, W, 3) uint8 array, or ``None`` when no image could be loaded.
    """
    rows = []
    for frame_idx, (fname, blur_sc) in enumerate(selected):
        cols = []
        for key, label, root in METHODS:
            fpath = root / scene / fname
            if not fpath.exists():
                continue
            with Image.open(fpath) as src:
                rgb = src.convert("RGB")
            w, h = rgb.size
            # Downscale each side by 3x.
            thumb = np.array(rgb.resize((w // 3, h // 3), Image.LANCZOS))
            per_frame = scores_by_method.get(key, {}).get("per_frame", [])
            sc_val = per_frame[frame_idx] if frame_idx < len(per_frame) else None
            if key == "blur":
                bar_text = f"Blur {blur_sc:.3f}"
            elif sc_val is not None:
                bar_text = f"{label} {sc_val:.3f}"
            else:
                bar_text = label
            bar = label_bar(bar_text, thumb.shape[1], h=30)
            cols.append(np.vstack([bar, thumb]))
        if cols:
            row_h = max(c.shape[0] for c in cols)
            rows.append(np.hstack([_pad_to(c, row_h, c.shape[1]) for c in cols]))
    if not rows:
        return None
    # Rows differ in width when a method lacks some frames; pad them so
    # np.vstack does not raise.
    grid_w = max(r.shape[1] for r in rows)
    # NOTE(review): the title contains non-ASCII "—" and "→"; Pillow's legacy
    # bitmap default font only encodes Latin-1 — confirm with the installed Pillow.
    title = label_bar(f"{scene} — nima-koniq score (frames ordered blur→sharp)",
                      grid_w, h=30, bg=(50, 30, 50))
    return np.vstack([title] + [_pad_to(r, r.shape[0], grid_w) for r in rows])


def _print_summary(all_results):
    """Print per-scene mean scores for every method plus a cross-scene MEAN row."""
    keys = [key for key, _, _ in METHODS]

    def fmt(value):
        return f" {value:>10.4f}" if value is not None else f" {'—':>10s}"

    print(f"\n{'=' * 70}")
    print(f"{'Scene':15s}" + "".join(f" {label:>10s}" for _, label, _ in METHODS))
    print("-" * 70)
    method_means = {k: [] for k in keys}
    for scene, res in all_results.items():
        sc = res["scores"]
        means = [sc[k]["mean"] if k in sc else None for k in keys]
        print(f"{scene:15s}" + "".join(fmt(m) for m in means))
        for k, m in zip(keys, means):
            if m is not None:
                method_means[k].append(m)
    print("-" * 70)
    print(f"{'MEAN':15s}" + "".join(
        fmt(float(np.mean(method_means[k])) if method_means[k] else None) for k in keys))


def main():
    """Select stratified frames per scene by blur-input IQA and compare methods.

    For every scene under the blur root: score all blur frames with nima-koniq,
    pick ``--n-frames`` frames evenly spread over the sorted scores, score every
    method in ``METHODS`` on those frames, and save a comparison grid. Finally
    print a summary table and write ``iqa_results.json`` to ``--out-dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=2)
    parser.add_argument("--n-frames", type=int, default=10)
    parser.add_argument("--scenes", nargs="*", default=None)
    parser.add_argument("--out-dir", default=str(BASE / "outputs/eval/iqa_selection"))
    args = parser.parse_args()

    device = _select_device(args.gpu)
    print(f"Loading nima-koniq on {device} ...")
    metric = pyiqa.create_metric("nima-koniq", device=device)

    blur_root = BASE / "data/scannet_blur_proto/vddiff/test/blur"
    if args.scenes:
        scenes = args.scenes
    else:
        scenes = sorted(d.name for d in blur_root.iterdir() if d.is_dir())

    out_dir = Path(args.out_dir)
    grid_dir = out_dir / "grids"
    grid_dir.mkdir(parents=True, exist_ok=True)

    all_results = {}
    for scene in scenes:
        blur_dir = blur_root / scene
        blur_frames = sorted(blur_dir.glob("*.png"))
        if not blur_frames:
            print(f"\n[{scene}] no blur frames in {blur_dir}, skipping")
            continue
        print(f"\n[{scene}] scoring {len(blur_frames)} blur frames ...")
        # Stable sort by score ascending (lowest quality first).
        blur_scores = sorted(
            ((f.name, score_image(metric, f, device)) for f in blur_frames),
            key=lambda item: item[1],
        )

        selected = _stratified_pick(blur_scores, args.n_frames)
        selected_names = [name for name, _ in selected]
        picked = [f"{name}={score:.3f}" for name, score in selected]
        print(f" Selected frames (blur score low→high): {picked}")

        scene_results = {
            "selected_frames": selected_names,
            "scores": _score_methods(metric, scene, selected_names, device),
        }
        all_results[scene] = scene_results

        grid = _build_grid(scene, selected, scene_results["scores"])
        if grid is not None:
            Image.fromarray(grid).save(grid_dir / f"{scene}_iqa.jpg", quality=92)
            print(f" grid → {out_dir}/grids/{scene}_iqa.jpg")

    _print_summary(all_results)
    with open(out_dir / "iqa_results.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults → {out_dir}/")


if __name__ == "__main__":
    main()