# Source: blur-slam-bpn-code / scripts/make_qualitative_figures.py
# Author: zhaoshiwen
# Commit c75b162 (verified): "Initial upload: BPN deblur pipeline code
#   (scripts, triangle-splatting, BAGS, EVSSM forks)"
# Size: 6.85 kB
"""
Generate qualitative comparison figures:
1. Full-frame side-by-side: sharp | blur | evssm | vdiff | turtle
2. Zoom-in crops showing detail
3. TURTLE artifact analysis
"""
import argparse
import numpy as np
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
# Repository root: this script lives in <root>/scripts/.
BASE = Path(__file__).parent.parent
# Figure columns, left to right, as (key, display label, image root).
# Each root holds one sub-directory per scene containing "{frame:06d}.png" files.
METHODS = [
    ("sharp", "Sharp GT", BASE.joinpath("data", "scannet_blur_proto", "vddiff", "test", "sharp")),
    ("blur", "Blur input", BASE.joinpath("data", "scannet_blur_proto", "vddiff", "test", "blur")),
    ("evssm", "EVSSM", BASE.joinpath("data", "evssm_deblurred")),
    ("vdiff", "VD-Diff", BASE.joinpath("data", "vdiff_deblurred")),
    ("turtle", "TURTLE", BASE.joinpath("data", "turtle_deblurred")),
]
def label_bar(text, w, h=32, bg=(30, 30, 30), fg=(255, 255, 255)):
    """Render *text* centred on a solid ``w`` x ``h`` bar and return it as an array.

    Args:
        text: label to draw (PIL default font).
        w: bar width in pixels.
        h: bar height in pixels.
        bg: RGB background colour.
        fg: RGB text colour.

    Returns:
        The bar as a NumPy array of the RGB image (shape ``(h, w, 3)``).
    """
    img = Image.new("RGB", (w, h), bg)
    draw = ImageDraw.Draw(img)
    # textbbox gives (left, top, right, bottom) relative to the draw origin.
    # left/top can be non-zero (font side bearing / ascender gap), so subtract
    # them: this centres the visible glyphs rather than the drawing origin.
    left, top, right, bottom = draw.textbbox((0, 0), text)
    x = (w - (right - left)) // 2 - left
    y = (h - (bottom - top)) // 2 - top
    draw.text((x, y), text, fill=fg)
    return np.array(img)
def psnr(a, b, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images.

    Args:
        a, b: image arrays (anything ``np.asarray`` accepts); compared
            element-wise in float64.
        max_val: peak signal value. The default 255 matches 8-bit images;
            pass 1.0 for images normalised to [0, 1].

    Returns:
        ``20 * log10(max_val / sqrt(MSE))``, or 100.0 when the MSE is
        effectively zero (identical images).

    Raises:
        ValueError: if ``a`` and ``b`` differ in shape. NumPy broadcasting
            would otherwise silently produce a meaningless score.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"psnr: shape mismatch {a.shape} vs {b.shape}")
    mse = np.mean((a - b) ** 2)
    # Cap identical images at 100 dB instead of dividing by zero.
    return 20 * np.log10(max_val / np.sqrt(mse)) if mse > 1e-10 else 100.0
def load(path):
    """Read the image at *path* and return it as an RGB NumPy array.

    The ``with`` block closes the underlying file as soon as the pixels are
    decoded. Without it, cleanup is left to PIL's lazy-loading and garbage
    collection.
    """
    with Image.open(path) as im:
        return np.array(im.convert("RGB"))
def add_label(img_arr, text, psnr_val=None):
    """Return *img_arr* with a 36-px label bar stacked on top of it.

    Args:
        img_arr: image array; presumably (H, W, 3) uint8 so it stacks under
            the RGB bar (all callers here pass such arrays).
        text: label text.
        psnr_val: optional PSNR. When given (including 0.0), " PSNR=<val>"
            with two decimals is appended to the label.

    Returns:
        ``np.vstack([bar, img_arr])``, where the bar is as wide as the image.
    """
    w = img_arr.shape[1]
    # `is not None`, not truthiness: a PSNR of exactly 0.0 is a valid score
    # and must still be shown.
    label = f"{text}" if psnr_val is None else f"{text} PSNR={psnr_val:.2f}"
    bar = label_bar(label, w, h=36)
    return np.vstack([bar, img_arr])
def make_comparison_row(scene, frame_idx, crop_box=None, scale=0.4):
    """Return one labelled row of images: [sharp | blur | evssm | vdiff | turtle].

    For each entry of ``METHODS``, the frame ``<root>/<scene>/{frame_idx:06d}.png``
    goes through these steps:
      * resized to the sharp frame's resolution if it differs;
      * optionally cropped;
      * scored with PSNR against the identically cropped sharp frame;
      * optionally downscaled for display;
      * topped with a label bar.

    Args:
        scene: scene sub-directory name under each method root.
        frame_idx: integer frame index.
        crop_box: optional ``(x1, y1, x2, y2)`` box, applied at full resolution
            as ``img[y1:y2, x1:x2]`` (x = column, y = row). When a crop is
            given, no display downscaling is applied.
        scale: display downscale factor for full frames; used only when < 1.0.

    Returns:
        The horizontally stacked row, or ``None`` if the sharp frame is missing
        or no method image exists. Methods whose frame is missing are skipped,
        so a row may have fewer columns than ``METHODS``.
    """
    frame_name = f"{frame_idx:06d}.png"
    sharp_path = BASE / "data/scannet_blur_proto/vddiff/test/sharp" / scene / frame_name
    if not sharp_path.exists():
        return None
    sharp = load(sharp_path)
    sharp_h, sharp_w = sharp.shape[:2]

    # The reference crop is the same for every method: compute it once.
    if crop_box is not None:
        x1, y1, x2, y2 = crop_box
        ref = sharp[y1:y2, x1:x2]
    else:
        ref = sharp
    downscale = crop_box is None and scale < 1.0

    cols = []
    for key, label, root in METHODS:
        path = root / scene / frame_name
        if not path.exists():
            continue
        img = load(path)
        # Bring every method to the GT resolution so crops and PSNR line up.
        if img.shape[:2] != (sharp_h, sharp_w):
            img = np.array(Image.fromarray(img).resize((sharp_w, sharp_h)))
        if crop_box is not None:
            img = img[y1:y2, x1:x2]
        # Score at full (cropped) resolution, BEFORE the display downscale.
        # LANCZOS downsampling low-pass filters both images, which hides blur
        # and would inflate the reported PSNR.
        p = psnr(img, ref) if key != "sharp" else None
        if downscale:
            nh = int(img.shape[0] * scale)
            nw = int(img.shape[1] * scale)
            img = np.array(Image.fromarray(img).resize((nw, nh), Image.LANCZOS))
        cols.append(add_label(img, label, p))

    if not cols:
        return None
    # Defensive: pad shorter columns with black rows at the bottom so that
    # np.hstack gets equal heights.
    max_h = max(c.shape[0] for c in cols)
    padded = [
        np.pad(c, ((0, max_h - c.shape[0]), (0, 0), (0, 0))) if c.shape[0] < max_h else c
        for c in cols
    ]
    return np.hstack(padded)
def make_scene_grid(scene, frame_idxs, out_path, crop_box=None, scale=0.4):
    """Stack one comparison row per frame under a scene title bar and save it.

    Args:
        scene: scene sub-directory name.
        frame_idxs: frame indices; frames with no sharp GT (or no method
            images) are skipped.
        out_path: destination image path (saved with ``quality=93``).
        crop_box: optional ``(x1, y1, x2, y2)`` crop forwarded to
            ``make_comparison_row``.
        scale: display downscale forwarded to ``make_comparison_row``.

    Prints a message and saves nothing if no frame produced a row.
    """
    candidates = (make_comparison_row(scene, fi, crop_box=crop_box, scale=scale) for fi in frame_idxs)
    rows = [row for row in candidates if row is not None]
    if not rows:
        print(f" No rows for {scene}")
        return
    # make_comparison_row skips methods whose frame is missing, so rows can
    # differ in width and np.vstack would raise. Pad them on the right with
    # black. Columns after a gap then shift left, which beats crashing.
    max_w = max(r.shape[1] for r in rows)
    rows = [
        np.pad(r, ((0, 0), (0, max_w - r.shape[1]), (0, 0))) if r.shape[1] < max_w else r
        for r in rows
    ]
    # add scene label
    scene_bar = label_bar(f"Scene: {scene}" + (f" crop={crop_box}" if crop_box else ""),
                          max_w, h=28, bg=(60, 30, 30))
    grid = np.vstack([scene_bar] + rows)
    Image.fromarray(grid).save(out_path, quality=93)
    print(f" saved -> {out_path}")
def _save_turtle_analysis(scene, frame_idx, out_dir, scale):
    """Save ``<scene>_turtle_analysis.jpg`` showing where TURTLE changes a frame.

    Columns, each resized by ``scale``:
      1. blur input;
      2. TURTLE output, labelled with its PSNR vs. sharp;
      3. sharp GT;
      4. |TURTLE - blur|;
      5. |blur - sharp|;
      6. |TURTLE - sharp|.

    The difference maps are amplified 3x and clipped to [0, 255]. Nothing is
    written if any of the three input frames is missing.
    """
    frame_name = f"{frame_idx:06d}.png"
    blur_path = BASE / "data/scannet_blur_proto/vddiff/test/blur" / scene / frame_name
    turtle_path = BASE / "data/turtle_deblurred" / scene / frame_name
    sharp_path = BASE / "data/scannet_blur_proto/vddiff/test/sharp" / scene / frame_name
    # All three frames are loaded below, so all three must exist.
    if not all(p.exists() for p in (blur_path, turtle_path, sharp_path)):
        return

    sharp_u8 = load(sharp_path)
    h, w = sharp_u8.shape[:2]

    def match_sharp(arr):
        # Outputs may differ in size from the GT (make_comparison_row resizes
        # for the same reason); element-wise diffs need identical shapes.
        if arr.shape[:2] != (h, w):
            arr = np.array(Image.fromarray(arr).resize((w, h)))
        return arr

    blur = match_sharp(load(blur_path)).astype(float)
    turtle = match_sharp(load(turtle_path)).astype(float)
    sharp = sharp_u8.astype(float)

    def amplified_diff(a, b):
        # |a - b| amplified 3x so small residuals become visible.
        return np.clip(np.abs(a - b) * 3, 0, 255).astype(np.uint8)

    nh = int(h * scale)
    nw = int(w * scale)

    def rs(a):
        return np.array(Image.fromarray(a).resize((nw, nh), Image.LANCZOS))

    cols = [
        add_label(rs(blur.astype(np.uint8)), "Blur input"),
        add_label(rs(turtle.astype(np.uint8)), f"TURTLE PSNR={psnr(turtle, sharp):.2f}"),
        add_label(rs(sharp.astype(np.uint8)), "Sharp GT"),
        add_label(rs(amplified_diff(turtle, blur)), "Diff: TURTLE vs Blur (×3)"),
        add_label(rs(amplified_diff(blur, sharp)), "Diff: Blur vs Sharp (×3)"),
        add_label(rs(amplified_diff(turtle, sharp)), "Diff: TURTLE vs Sharp (×3)"),
    ]
    grid = np.hstack(cols)
    path = out_dir / f"{scene}_turtle_analysis.jpg"
    Image.fromarray(grid).save(path, quality=93)
    print(f" turtle analysis -> {path}")


def main():
    """Command-line entry point: write all qualitative figures for each scene.

    For every scene in ``--scenes`` this writes into ``--out-dir``:
      * ``<scene>_overview.jpg``: full frames for ``--frames``, scaled by ``--scale``;
      * ``<scene>_crop_<name>.jpg``: native-resolution crops of the first two frames;
      * ``<scene>_turtle_analysis.jpg``: TURTLE difference maps for the first frame.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--out-dir", default=str(BASE / "outputs/eval/qualitative"))
    p.add_argument("--scenes", nargs="*", default=["scene0000_00", "scene0000_01", "scene0000_02"])
    p.add_argument("--frames", nargs="*", type=int, default=[0, 15, 30, 45])
    p.add_argument("--scale", type=float, default=0.4)
    args = p.parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Crop regions for detail analysis in (y1, x1, y2, x2) format, rows first.
    # make_comparison_row expects (x1, y1, x2, y2), so they are reordered below.
    # Read this way, "bottomright" ends at (968, 1296), i.e. exactly the corner
    # of a 968x1296 (H x W) frame.
    crops = {
        "topleft": (0, 0, 320, 480),
        "center": (300, 400, 620, 880),
        "bottomright": (600, 700, 968, 1296),
    }
    for scene in args.scenes:
        print(f"\n[{scene}]")
        # 1. Full-frame overview (downscaled)
        make_scene_grid(scene, args.frames,
                        out_dir / f"{scene}_overview.jpg",
                        scale=args.scale)
        # 2. Zoomed crops at native resolution
        for crop_name, (y1, x1, y2, x2) in crops.items():
            make_scene_grid(scene, args.frames[:2],
                            out_dir / f"{scene}_crop_{crop_name}.jpg",
                            crop_box=(x1, y1, x2, y2), scale=1.0)
        # 3. TURTLE diff map (show where TURTLE changes the image).
        # `--frames` accepts zero values (nargs="*"); skip when there is none.
        if args.frames:
            _save_turtle_analysis(scene, args.frames[0], out_dir, args.scale)
    print(f"\nAll figures saved to {out_dir}/")
if __name__ == "__main__":
    # Build the figures only when run as a script, never on import.
    main()