# Source: blur-slam-bpn-code / scripts/iqa_frame_selection.py
# Uploaded by zhaoshiwen (commit c75b162): "Initial upload: BPN deblur pipeline
# code (scripts, triangle-splatting, BAGS, EVSSM forks)"
"""
Step 1: Compute nima-koniq scores on the blur input frames → select --n-frames (default 10) stratified frames per scene
Step 2: Score the EVSSM, VD-Diff and TURTLE outputs on the same frames
Step 3: Output a comparison table (stdout + iqa_results.json) and one visual grid per scene
"""
import argparse, json
import numpy as np
from pathlib import Path
from PIL import Image, ImageDraw
import torch
import pyiqa
# Directory two levels above this file; every data/output path below is
# resolved relative to it.
BASE = Path(__file__).parent.parent
# One entry per image source to score: (key, display label, root directory).
# main() looks frames up as <root>/<scene>/<frame name>; `key` is used as the
# dictionary key in the results and `label` as the caption text.
METHODS = [
    ("blur", "Blur Input", BASE / "data/scannet_blur_proto/vddiff/test/blur"),
    ("evssm", "EVSSM", BASE / "data/evssm_deblurred"),
    ("vdiff", "VD-Diff", BASE / "data/vdiff_deblurred"),
    ("turtle", "TURTLE", BASE / "data/turtle_deblurred"),
]
def label_bar(text, w, h=32, bg=(30,30,30), fg=(255,255,255)):
    """Render a solid caption strip with ``text`` centred on it.

    Args:
        text: Caption to draw, using PIL's default font.
        w: Strip width in pixels.
        h: Strip height in pixels.
        bg: RGB background colour.
        fg: RGB text colour.

    Returns:
        A numpy array of shape ``(h, w, 3)`` holding the rendered RGB strip.
    """
    img = Image.new("RGB", (w, h), bg)
    draw = ImageDraw.Draw(img)
    left, top, right, bottom = draw.textbbox((0, 0), text)
    tw, th = right - left, bottom - top
    # textbbox() reports where the ink lands relative to the anchor, and that
    # box does not necessarily start at (0, 0).  Subtract its origin so the
    # visible glyphs (not the anchor point) end up centred in the strip.
    draw.text(((w - tw) // 2 - left, (h - th) // 2 - top), text, fill=fg)
    return np.array(img)
def score_image(metric, img_path, device):
    """Score one image file with ``metric`` and return the result as a float.

    Args:
        metric: Callable applied to a ``(1, C, H, W)`` float tensor; its
            result must support ``.item()``.
        img_path: Path of the image to load (converted to RGB).
        device: Torch device the input tensor is moved to before scoring.

    Returns:
        The scalar metric value as a Python ``float``.
    """
    from torchvision import transforms

    rgb = Image.open(img_path).convert("RGB")
    to_tensor = transforms.ToTensor()
    # Add a leading batch dimension: the metric is fed a single-image batch.
    batch = to_tensor(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        result = metric(batch)
        return float(result.item())
def _resolve_device(gpu_index):
    """Return the torch device for ``gpu_index``, falling back when unavailable."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if 0 <= gpu_index < torch.cuda.device_count():
        return torch.device(f"cuda:{gpu_index}")
    # Keep the script usable on machines with fewer GPUs than the default index.
    print(f"Warning: cuda:{gpu_index} is not available, using cuda:0 instead")
    return torch.device("cuda:0")


def _stratified_indices(total, count):
    """Return up to ``count`` distinct, evenly spaced indices into ``range(total)``."""
    count = max(0, min(count, total))
    if count == 0:
        return []
    # np.unique guards against repeated indices after integer truncation.
    return np.unique(np.linspace(0, total - 1, count, dtype=int)).tolist()


def _pad_to(arr, height, width):
    """Zero-pad an ``(H, W, 3)`` uint8 image on the bottom/right to the given size."""
    padded = np.zeros((height, width, 3), np.uint8)
    padded[: arr.shape[0], : arr.shape[1]] = arr
    return padded


def _fmt_mean(scores, key):
    """Format the mean score stored under ``key`` for the summary table."""
    mean = scores.get(key, {}).get("mean")
    return f"{mean:.4f}" if mean is not None else " — "


def main():
    """Select stratified blur frames per scene, score every method, and report.

    Command-line flags:
        --gpu       CUDA device index to run the metric on (CPU if CUDA is
                    unavailable; cuda:0 if the index does not exist).
        --n-frames  Number of frames to pick per scene.
        --scenes    Scene names to process; defaults to every sub-directory of
                    the blur input root.
        --out-dir   Output directory for ``iqa_results.json`` and ``grids/``.

    For each scene the blur frames (``*.png``) are scored with nima-koniq,
    sorted by score, and sampled at evenly spaced ranks.  The same file names
    are then scored under every root listed in ``METHODS``; a JPEG grid
    (frames × methods) is written per scene, a summary table is printed, and
    all scores are dumped to JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=2)
    parser.add_argument("--n-frames", type=int, default=10)
    parser.add_argument("--scenes", nargs="*", default=None)
    parser.add_argument("--out-dir", default=str(BASE / "outputs/eval/iqa_selection"))
    args = parser.parse_args()

    # Honour --gpu (previously parsed but ignored) and log the real device.
    device = _resolve_device(args.gpu)
    print(f"Loading nima-koniq on {device} ...")
    metric = pyiqa.create_metric("nima-koniq", device=device)

    blur_root = BASE / "data/scannet_blur_proto/vddiff/test/blur"
    if args.scenes:
        scenes = args.scenes
    else:
        scenes = sorted(d.name for d in blur_root.iterdir() if d.is_dir())

    out_dir = Path(args.out_dir)
    (out_dir / "grids").mkdir(parents=True, exist_ok=True)

    all_results = {}
    for scene in scenes:
        blur_frames = sorted((blur_root / scene).glob("*.png"))
        if not blur_frames:
            continue
        print(f"\n[{scene}] scoring {len(blur_frames)} blur frames ...")

        # Score all blur frames, then order them by ascending score
        # (treated here as "most blurry first").
        blur_scores = [
            (frame.name, score_image(metric, frame, device))
            for frame in blur_frames
        ]
        blur_scores.sort(key=lambda item: item[1])

        # Pick frames at evenly spaced ranks of the score distribution.  The
        # count is clamped to the number of frames so none is selected twice.
        idxs = _stratified_indices(len(blur_scores), args.n_frames)
        selected = [blur_scores[i] for i in idxs]
        selected_names = [name for name, _ in selected]
        print(f" Selected frames (blur score low→high): {[f'{n}={s:.3f}' for n,s in selected]}")

        # Score all methods on the selected frames.
        scene_results = {"selected_frames": selected_names, "scores": {}}
        for key, label, root in METHODS:
            method_dir = root / scene
            if not method_dir.exists():
                continue
            scores = []
            for fname in selected_names:
                fpath = method_dir / fname
                # None keeps per_frame aligned with selected_frames.
                scores.append(
                    score_image(metric, fpath, device) if fpath.exists() else None
                )
            valid = [s for s in scores if s is not None]
            mean_score = float(np.mean(valid)) if valid else None
            scene_results["scores"][key] = {
                "per_frame": scores, "mean": mean_score, "label": label
            }
            # Compare against None so a genuine 0.0 score is still reported.
            print(f" {label:12s}: mean={mean_score:.4f}" if mean_score is not None else f" {label}: N/A")
        all_results[scene] = scene_results

        # Visual grid: selected frames (rows) × methods (columns).
        rows = []
        for frame_idx, (fname, blur_sc) in enumerate(selected):
            cols = []
            for key, label, root in METHODS:
                fpath = root / scene / fname
                if not fpath.exists():
                    continue
                with Image.open(fpath) as im:
                    rgb = im.convert("RGB")
                # Downscale to a third of the size, never below one pixel.
                w, h = rgb.size
                img = np.array(
                    rgb.resize((max(1, w // 3), max(1, h // 3)), Image.LANCZOS)
                )
                per_frame = scene_results["scores"].get(key, {}).get("per_frame", [])
                sc_val = per_frame[frame_idx] if frame_idx < len(per_frame) else None
                bar_text = f"{label} {sc_val:.3f}" if sc_val is not None else label
                if key == "blur":
                    bar_text = f"Blur {blur_sc:.3f}"
                bar = label_bar(bar_text, img.shape[1], h=30)
                cols.append(np.vstack([bar, img]))
            if cols:
                max_h = max(c.shape[0] for c in cols)
                rows.append(
                    np.hstack([_pad_to(c, max_h, c.shape[1]) for c in cols])
                )
        if rows:
            # Rows can differ in width when a method lacks a frame or image
            # sizes vary; pad to a common width so they can be stacked.
            grid_w = max(r.shape[1] for r in rows)
            title = label_bar(f"{scene} — nima-koniq score (frames ordered blur→sharp)",
                              grid_w, h=30, bg=(50,30,50))
            grid = np.vstack([title] + [_pad_to(r, r.shape[0], grid_w) for r in rows])
            Image.fromarray(grid).save(out_dir / "grids" / f"{scene}_iqa.jpg", quality=92)
            print(f" grid → {out_dir}/grids/{scene}_iqa.jpg")

    # Summary table
    summary_keys = ["blur", "evssm", "vdiff", "turtle"]
    print(f"\n{'='*70}")
    print(f"{'Scene':15s} {'Blur':8s} {'EVSSM':8s} {'VD-Diff':8s} {'TURTLE':8s}")
    print("-"*70)
    method_means = {k: [] for k in summary_keys}
    for scene, res in all_results.items():
        sc = res["scores"]
        cells = [_fmt_mean(sc, k) for k in summary_keys]
        print(f"{scene:15s} {cells[0]} {cells[1]} {cells[2]} {cells[3]}")
        for k in summary_keys:
            mean = sc.get(k, {}).get("mean")
            if mean is not None:
                method_means[k].append(mean)
    print("-"*70)
    print(f"{'MEAN':15s}", end=" ")
    for k in summary_keys:
        vals = method_means[k]
        print(f"{np.mean(vals):.4f} " if vals else " — ", end="")
    print()

    with open(out_dir / "iqa_results.json", "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults → {out_dir}/")


if __name__ == "__main__":
    main()