Tejaswi Tripathi
Quantitative tests
13d40ab
Raw
History Blame Contribute Delete
17.3 kB
#!/usr/bin/env python3
"""Benchmark edge artifacts for the prototype rendering pipeline.
For each JPG in cache/, runs the same NN + Gaussian-background pipeline as
prototype.py, then compares edge artifacts against the original image:
1. Flat Gaussian background vs original
2. CoC-weighted NN render vs original
Metrics are computed in regions where compositing is most likely to leave
visible halos: in-focus pixels (should match the original), the CoC transition
band, and a narrow ring around the Gaussian mask boundary (CoC threshold).
Usage (from repo root):
python quantitative-tests/benchmark_edge_artifacts.py
python quantitative-tests/benchmark_edge_artifacts.py --max-side 1536
"""
from __future__ import annotations
import argparse
import csv
import io
import json
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import boto3
import numpy as np
import torch
from PIL import Image
from skimage.filters import gaussian, sobel
from skimage.morphology import dilation, disk, erosion
from skimage.transform import resize
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))
from data_preprocessing import ( # noqa: E402
COC_PX_MAX,
FOCAL_LENGTH_MM_MAX,
F_STOP_MAX,
TARGET_SIZE,
)
from depth_anything_inference import ( # noqa: E402
generate_pseudo_coc_from_relative_depth,
load_depth_anything,
predict_relative_depth,
)
from models.RendererNet import RendererNet # noqa: E402
# Match prototype.py defaults.
# S3 location of the RendererNet checkpoint fetched by load_renderer().
S3_BUCKET = "tejas-blender-bucket"
RENDERER_KEY = "defocus-checkpoints/renderer-net/best_renderer.pth"
# Lens settings; make_param_maps() divides them by F_STOP_MAX and
# FOCAL_LENGTH_MM_MAX before feeding them to the network as constant planes.
F_STOP = 1.2
FOCAL_LENGTH_MM = 6.765
# Depth model variant and local checkpoint handed to load_depth_anything().
DEPTH_ENCODER = "vitb"
DEPTH_CHECKPOINT = REPO_ROOT / "checkpoints" / "depth_anything_v2_vitb.pth"
# Upper clamp applied to the pseudo CoC map in pseudo_coc_px().
# NOTE(review): this is a different name from the imported COC_PX_MAX, which
# coc_px_to_norm_512() uses for normalisation -- confirm the two are meant to
# be independent.
COC_MAX_PX = 4.0
# Pixels with CoC at or below this value count as "in focus" (region_masks)
# and receive zero NN blend weight (coc_blend_weight).
COC_FOCUS_THRESHOLD_PX = 0.4
# Pixels with CoC above this value are replaced by the Gaussian blur in
# render_gaussian_background(); also defines the boundary-ring mask.
GAUSSIAN_COC_THRESHOLD_PX = 1.0
# Sigma passed to skimage.filters.gaussian for the flat background blur.
GAUSSIAN_SIGMA_PX = 12.0
# CoC interval (low, high] that region_masks() labels as the transition band.
TRANSITION_COC_LOW = COC_FOCUS_THRESHOLD_PX
TRANSITION_COC_HIGH = GAUSSIAN_COC_THRESHOLD_PX
# Radius of the disk footprint used to dilate/erode the Gaussian mask when
# building the boundary ring.
BOUNDARY_RING_RADIUS_PX = 3
@dataclass
class EdgeArtifactMetrics:
    """Artifact statistics for one processed image compared to its original.

    Built by ``compute_edge_artifact_metrics``. "abs_diff" values are taken
    from the luminance of the absolute RGB difference; "grad_excess" values
    are the absolute difference between the Sobel magnitudes of the processed
    and original luminance images. Regional values are 0.0 when the region
    contains no pixels.
    """

    # Region: CoC <= COC_FOCUS_THRESHOLD_PX.
    in_focus_mean_abs_diff: float
    in_focus_mean_grad_excess: float
    # Region: TRANSITION_COC_LOW < CoC <= TRANSITION_COC_HIGH.
    transition_mean_abs_diff: float
    transition_mean_grad_excess: float
    # Region: ring around the boundary of the CoC > GAUSSIAN_COC_THRESHOLD_PX mask.
    boundary_ring_mean_abs_diff: float
    boundary_ring_mean_grad_excess: float
    # 95th percentile of the luminance difference inside the boundary ring.
    boundary_ring_p95_abs_diff: float
    # Mean gradient difference over every pixel of the image.
    global_mean_grad_excess: float
@dataclass
class ImageBenchmark:
    """Benchmark result for a single image, as produced by ``benchmark_image``."""

    # File name (without directory) of the benchmarked image.
    image: str
    # Dimensions of the image after the optional max-side downscale.
    width: int
    height: int
    # Share of pixels (count / (height * width)) falling into each region mask.
    in_focus_fraction: float
    transition_fraction: float
    boundary_ring_fraction: float
    # Metrics for the flat Gaussian-background composite vs the original.
    gaussian: EdgeArtifactMetrics
    # Metrics for the CoC-weighted NN render vs the original.
    nn_render: EdgeArtifactMetrics
    # Metric name -> Gaussian value divided by NN value (see improvement_ratio).
    nn_vs_gaussian_improvement: dict[str, float]
def list_cache_jpgs() -> list[Path]:
    """Return the ``*.jpg`` files directly inside ``<repo>/cache``, sorted.

    Directory entries that match the pattern but are not regular files are
    skipped.
    """
    jpgs = [
        entry
        for entry in (REPO_ROOT / "cache").glob("*.jpg")
        if entry.is_file()
    ]
    jpgs.sort()
    return jpgs
def load_checkpoint_from_s3(s3_key: str, map_location: str) -> object:
    """Download ``s3_key`` from ``S3_BUCKET`` into memory and ``torch.load`` it.

    Args:
        s3_key: Object key inside ``S3_BUCKET``.
        map_location: Forwarded to ``torch.load`` to place loaded tensors.

    Returns:
        Whatever object the checkpoint file deserialises to.
    """
    payload = io.BytesIO()
    boto3.client("s3").download_fileobj(S3_BUCKET, s3_key, payload)
    # Rewind so torch.load reads from the first byte of the download.
    payload.seek(0)
    # NOTE(review): torch.load unpickles the payload; only point this at a
    # bucket whose contents are trusted.
    return torch.load(payload, map_location=map_location)
def load_renderer(device: str) -> RendererNet:
    """Build a ``RendererNet`` on ``device`` with weights loaded from S3.

    The checkpoint may either be a dict carrying the weights under
    ``"model_state_dict"`` or the state dict itself. The model is returned in
    eval mode.
    """
    checkpoint = load_checkpoint_from_s3(RENDERER_KEY, map_location=device)
    net = RendererNet(in_channels=6, out_channels=3).to(device)

    wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    state_dict = checkpoint["model_state_dict"] if wrapped else checkpoint
    net.load_state_dict(state_dict)

    net.eval()
    return net
def fit_max_side(rgb: np.ndarray, max_side: int | None) -> np.ndarray:
    """Downscale ``rgb`` so that its longer side is at most ``max_side``.

    Args:
        rgb: Image array whose first two axes are height and width.
        max_side: Largest allowed length (in pixels) of the longer side, or
            ``None`` to disable resizing.

    Returns:
        ``rgb`` itself (not a copy) when no resizing is required; otherwise a
        float32 array resized with anti-aliasing, keeping the aspect ratio up
        to rounding.

    Raises:
        ValueError: If ``max_side`` is zero or negative.
    """
    if max_side is None:
        return rgb
    if max_side <= 0:
        # A non-positive limit would yield a zero/negative scale and an
        # invalid output shape further down.
        raise ValueError(f"max_side must be positive, got {max_side}")

    h, w = rgb.shape[:2]
    longest = max(h, w)
    if longest <= max_side:
        return rgb

    scale = max_side / float(longest)
    # Clamp to one pixel so very elongated images never round to a 0-sized axis.
    new_h = max(1, int(round(h * scale)))
    new_w = max(1, int(round(w * scale)))
    return resize(
        rgb, (new_h, new_w), anti_aliasing=True, preserve_range=True
    ).astype(np.float32)
def make_param_maps(size: int) -> tuple[np.ndarray, np.ndarray]:
    """Return constant ``(1, size, size)`` planes for f-stop and focal length.

    Each plane is filled with the module-level lens setting divided by its
    corresponding ``*_MAX`` constant.
    """
    plane_shape = (1, size, size)
    fstop_ratio = F_STOP / F_STOP_MAX
    focal_ratio = FOCAL_LENGTH_MM / FOCAL_LENGTH_MM_MAX

    fstop_plane = np.ones(plane_shape, dtype=np.float32) * fstop_ratio
    focal_plane = np.ones(plane_shape, dtype=np.float32) * focal_ratio
    return fstop_plane, focal_plane
def to_chw_resized(rgb: np.ndarray, size: int) -> np.ndarray:
    """Resize ``rgb`` to ``size x size`` and move the last axis to the front.

    The resize is anti-aliased and keeps the input value range; the result is
    float32 with axes reordered from (H, W, C) to (C, H, W).
    """
    square = resize(rgb, (size, size), anti_aliasing=True, preserve_range=True)
    square = square.astype(np.float32)
    return np.transpose(square, (2, 0, 1))
def coc_px_to_norm_512(coc_px: np.ndarray) -> np.ndarray:
    """Normalise a CoC map by ``COC_PX_MAX`` and resize it to the network size.

    Values are clipped to ``[0, COC_PX_MAX]`` before the division, then the
    map is bilinearly resized (``order=1``) to ``TARGET_SIZE x TARGET_SIZE``
    and returned as float32.
    """
    normalised = np.clip(coc_px, 0, COC_PX_MAX) / COC_PX_MAX
    target_shape = (TARGET_SIZE, TARGET_SIZE)
    resized = resize(
        normalised,
        target_shape,
        order=1,
        anti_aliasing=True,
        preserve_range=True,
    )
    return resized.astype(np.float32)
@torch.no_grad()
def run_renderer(
    model: RendererNet,
    device: str,
    rgb: np.ndarray,
    coc_norm_512: np.ndarray,
    out_size: tuple[int, int],
) -> np.ndarray:
    """Run the renderer network on ``rgb`` and resize its output to ``out_size``.

    The network input stacks, along the channel axis, the image resized to
    ``TARGET_SIZE``, the f-stop plane, the focal-length plane and the
    normalised CoC map. Non-finite input values are replaced before
    inference. The network output is clipped to ``[0, 1]``.

    Args:
        model: Renderer network, already on ``device``.
        device: Torch device string the input tensor is moved to.
        rgb: Source image, (H, W, C).
        coc_norm_512: Normalised CoC map at ``TARGET_SIZE`` resolution.
        out_size: ``(height, width)`` of the returned image.

    Returns:
        float32 (H, W, C) image at ``out_size`` resolution.
    """
    fstop_plane, focal_plane = make_param_maps(TARGET_SIZE)
    planes = [
        to_chw_resized(rgb, TARGET_SIZE),
        fstop_plane,
        focal_plane,
        coc_norm_512[None, :, :],
    ]
    # Leading [None] adds the batch axis expected by the model.
    batch = np.concatenate(planes, axis=0)[None]
    batch = np.nan_to_num(batch, nan=0.0, posinf=1.0, neginf=0.0)
    tensor = torch.from_numpy(batch.astype(np.float32)).to(device)

    prediction = model(tensor)[0].cpu().numpy()
    hwc = np.clip(np.transpose(prediction, (1, 2, 0)), 0, 1)
    upscaled = resize(hwc, out_size, anti_aliasing=True, preserve_range=True)
    return upscaled.astype(np.float32)
def pseudo_coc_px(
    rel_depth: np.ndarray, focus_y: int, focus_x: int
) -> np.ndarray:
    """Derive a pseudo CoC map (in pixels) from relative depth.

    Delegates to ``generate_pseudo_coc_from_relative_depth`` with
    ``coc_max_px=COC_MAX_PX`` and ``blur_strength=1.0``, then clips the result
    to ``[0, COC_MAX_PX]`` and casts it to float32.
    """
    raw_coc = generate_pseudo_coc_from_relative_depth(
        rel_depth,
        focus_y,
        focus_x,
        coc_max_px=COC_MAX_PX,
        blur_strength=1.0,
    )
    clipped = np.clip(raw_coc, 0, COC_MAX_PX)
    return clipped.astype(np.float32)
def coc_blend_weight(coc_px: np.ndarray) -> np.ndarray:
    """Map CoC (pixels) to a smoothstep blend weight in ``[0, 1]``.

    The weight is 0 at or below ``COC_FOCUS_THRESHOLD_PX`` and 1 at or above
    ``COC_MAX_PX``, following ``3s^2 - 2s^3`` in between.
    """
    # Floor the span so equal thresholds cannot cause a division by zero.
    span = max(COC_MAX_PX - COC_FOCUS_THRESHOLD_PX, 1e-6)
    s = np.clip((coc_px - COC_FOCUS_THRESHOLD_PX) / span, 0.0, 1.0)
    smooth = s * s * (3.0 - 2.0 * s)
    return smooth.astype(np.float32)
def render_nn(
    model: RendererNet,
    device: str,
    rgb: np.ndarray,
    coc_px: np.ndarray,
) -> np.ndarray:
    """Blend the network render with the original using a CoC-derived weight.

    The network runs at ``TARGET_SIZE`` and its output is resized back to the
    resolution of ``rgb``. Each pixel is then mixed as
    ``(1 - w) * rgb + w * render`` with ``w = coc_blend_weight(coc_px)``, and
    the result is clipped to ``[0, 1]``.
    """
    out_hw = rgb.shape[:2]
    network_out = run_renderer(
        model, device, rgb, coc_px_to_norm_512(coc_px), (out_hw[0], out_hw[1])
    )
    # Trailing axis lets the per-pixel weight broadcast across channels.
    w = coc_blend_weight(coc_px)[:, :, None]
    mixed = (1.0 - w) * rgb + w * network_out
    return np.clip(mixed, 0, 1).astype(np.float32)
def render_gaussian_background(rgb: np.ndarray, coc_px: np.ndarray) -> np.ndarray:
    """Composite a uniformly blurred copy of ``rgb`` behind a hard CoC mask.

    Pixels whose CoC exceeds ``GAUSSIAN_COC_THRESHOLD_PX`` are taken from a
    Gaussian blur of the whole image (sigma ``GAUSSIAN_SIGMA_PX``); all other
    pixels keep their original value. The result is clipped to ``[0, 1]``.
    """
    blurred = gaussian(
        rgb, sigma=GAUSSIAN_SIGMA_PX, channel_axis=-1, preserve_range=True
    )
    blurred = blurred.astype(np.float32)

    # Binary (0/1) selector with a trailing axis for channel broadcasting.
    use_blur = (coc_px > GAUSSIAN_COC_THRESHOLD_PX)[:, :, None]
    use_blur = use_blur.astype(np.float32)

    merged = (1.0 - use_blur) * rgb + use_blur * blurred
    return np.clip(merged, 0, 1).astype(np.float32)
def rgb_to_luminance(rgb: np.ndarray) -> np.ndarray:
    """Collapse the last axis of ``rgb`` into a float32 luminance map.

    Only the first three channels are used, weighted 0.299 / 0.587 / 0.114;
    any further channels are ignored.
    """
    channel_weights = [0.299, 0.587, 0.114]
    colour = rgb[..., :3]
    luma = np.dot(colour, channel_weights)
    return luma.astype(np.float32)
def grad_magnitude(gray: np.ndarray) -> np.ndarray:
    """Return the Sobel filter response of ``gray`` as float32."""
    edges = sobel(gray)
    return edges.astype(np.float32)
def region_masks(coc_px: np.ndarray) -> dict[str, np.ndarray]:
    """Build the boolean evaluation masks derived from a CoC map.

    Returns:
        A dict with three masks shaped like ``coc_px``:

        * ``"in_focus"``: CoC at or below ``COC_FOCUS_THRESHOLD_PX``.
        * ``"transition"``: CoC in ``(TRANSITION_COC_LOW, TRANSITION_COC_HIGH]``.
        * ``"boundary_ring"``: pixels where dilating and eroding the
          ``CoC > GAUSSIAN_COC_THRESHOLD_PX`` mask with a disk of radius
          ``BOUNDARY_RING_RADIUS_PX`` disagree.
    """
    blurred_region = coc_px > GAUSSIAN_COC_THRESHOLD_PX
    ring_footprint = disk(BOUNDARY_RING_RADIUS_PX)
    grown = dilation(blurred_region, ring_footprint)
    shrunk = erosion(blurred_region, ring_footprint)

    above_low = coc_px > TRANSITION_COC_LOW
    within_high = coc_px <= TRANSITION_COC_HIGH

    return {
        "in_focus": coc_px <= COC_FOCUS_THRESHOLD_PX,
        "transition": above_low & within_high,
        # XOR keeps only pixels that dilation added or erosion removed.
        "boundary_ring": grown ^ shrunk,
    }
def masked_mean(values: np.ndarray, mask: np.ndarray) -> float:
    """Return the mean of ``values`` where ``mask`` is set, or 0.0 if none is."""
    if np.any(mask):
        selected = values[mask]
        return float(np.mean(selected))
    return 0.0
def masked_p95(values: np.ndarray, mask: np.ndarray) -> float:
    """Return the 95th percentile of ``values`` under ``mask``, or 0.0 if empty."""
    if np.any(mask):
        selected = values[mask]
        return float(np.percentile(selected, 95))
    return 0.0
def compute_edge_artifact_metrics(
    original: np.ndarray,
    processed: np.ndarray,
    coc_px: np.ndarray,
) -> EdgeArtifactMetrics:
    """Measure how ``processed`` deviates from ``original`` in each CoC region.

    Two per-pixel maps are computed: the luminance of the absolute RGB
    difference, and the absolute difference between the Sobel magnitudes of
    the two luminance images. Both are averaged inside each mask returned by
    ``region_masks``; the boundary ring also gets a 95th-percentile
    difference, and the gradient map is additionally averaged over the whole
    image.
    """
    masks = region_masks(coc_px)

    abs_diff = rgb_to_luminance(np.abs(processed - original))
    grad_original = grad_magnitude(rgb_to_luminance(original))
    grad_processed = grad_magnitude(rgb_to_luminance(processed))
    grad_excess = np.abs(grad_processed - grad_original)

    stats: dict[str, float] = {}
    for region in ("in_focus", "transition", "boundary_ring"):
        region_mask = masks[region]
        stats[f"{region}_mean_abs_diff"] = masked_mean(abs_diff, region_mask)
        stats[f"{region}_mean_grad_excess"] = masked_mean(
            grad_excess, region_mask
        )
    stats["boundary_ring_p95_abs_diff"] = masked_p95(
        abs_diff, masks["boundary_ring"]
    )
    stats["global_mean_grad_excess"] = float(np.mean(grad_excess))

    return EdgeArtifactMetrics(**stats)
def improvement_ratio(gaussian_value: float, nn_value: float) -> float:
    """Return ``gaussian_value / nn_value`` with a guard for a ~zero divisor.

    When ``nn_value`` is at most 1e-12 the result is ``inf`` if
    ``gaussian_value`` exceeds 1e-12, and 1.0 otherwise (both effectively
    zero).
    """
    tiny = 1e-12
    if nn_value <= tiny:
        if gaussian_value > tiny:
            return float("inf")
        return 1.0
    return gaussian_value / nn_value
def benchmark_image(
    image_path: Path,
    renderer: RendererNet,
    depth_model,
    device: str,
    max_side: int | None,
) -> ImageBenchmark:
    """Run both rendering pipelines on one image and collect artifact metrics.

    Steps: load the image as float RGB in ``[0, 1]``, optionally downscale it,
    predict relative depth and resize it to the image resolution, derive a
    pseudo CoC map focused on the image centre, render the NN and
    Gaussian-background variants, and score each against the original.

    Args:
        image_path: Path of the image file to benchmark.
        renderer: Loaded renderer network.
        depth_model: Depth model passed through to ``predict_relative_depth``.
        device: Torch device string used for the renderer.
        max_side: Longest-side limit forwarded to ``fit_max_side``.

    Returns:
        An ``ImageBenchmark`` with region fractions, metrics for both
        pipelines, and Gaussian/NN improvement ratios.
    """
    # Use a context manager so the file handle is released deterministically
    # instead of relying on Pillow internals or garbage collection.
    with Image.open(image_path) as source:
        rgb = np.array(source.convert("RGB"), dtype=np.float32) / 255.0
    rgb = fit_max_side(rgb, max_side)
    h, w = rgb.shape[:2]

    rel_depth = predict_relative_depth(str(image_path), depth_model, normalize=True)
    rel_depth = resize(
        rel_depth, (h, w), order=1, anti_aliasing=True, preserve_range=True
    ).astype(np.float32)

    # Focus on the centre pixel of the (possibly downscaled) image.
    focus_y, focus_x = h // 2, w // 2
    coc_px = pseudo_coc_px(rel_depth, focus_y, focus_x)

    nn_image = render_nn(renderer, device, rgb, coc_px)
    gaussian_image = render_gaussian_background(rgb, coc_px)
    gaussian_metrics = compute_edge_artifact_metrics(rgb, gaussian_image, coc_px)
    nn_metrics = compute_edge_artifact_metrics(rgb, nn_image, coc_px)

    masks = region_masks(coc_px)
    pixel_count = float(h * w)

    # Metrics for which a Gaussian/NN ratio is reported (insertion order is
    # the order they are printed and serialised in).
    improvement_keys = (
        "boundary_ring_mean_abs_diff",
        "boundary_ring_mean_grad_excess",
        "transition_mean_grad_excess",
        "in_focus_mean_grad_excess",
    )
    improvement = {
        key: improvement_ratio(
            getattr(gaussian_metrics, key), getattr(nn_metrics, key)
        )
        for key in improvement_keys
    }

    return ImageBenchmark(
        image=image_path.name,
        width=w,
        height=h,
        in_focus_fraction=float(np.sum(masks["in_focus"]) / pixel_count),
        transition_fraction=float(np.sum(masks["transition"]) / pixel_count),
        boundary_ring_fraction=float(np.sum(masks["boundary_ring"]) / pixel_count),
        gaussian=gaussian_metrics,
        nn_render=nn_metrics,
        nn_vs_gaussian_improvement=improvement,
    )
def average_metrics(rows: list[ImageBenchmark], method: str) -> EdgeArtifactMetrics:
    """Average every metric field of ``rows[i].<method>`` across all rows.

    Args:
        rows: Per-image benchmark results.
        method: Attribute name on ``ImageBenchmark`` holding the metrics to
            average (``"gaussian"`` or ``"nn_render"``).
    """
    per_image = [getattr(row, method) for row in rows]
    averaged: dict[str, float] = {}
    for name in EdgeArtifactMetrics.__dataclass_fields__:
        samples = [getattr(metrics, name) for metrics in per_image]
        averaged[name] = float(np.mean(samples))
    return EdgeArtifactMetrics(**averaged)
def print_metric_block(title: str, metrics: EdgeArtifactMetrics) -> None:
    """Print ``title`` followed by one line per metric with six decimals."""
    # Labels are kept verbatim so console output stays unchanged.
    labelled_values = (
        (" in_focus_mean_abs_diff : ", metrics.in_focus_mean_abs_diff),
        (" in_focus_mean_grad_excess : ", metrics.in_focus_mean_grad_excess),
        (" transition_mean_abs_diff : ", metrics.transition_mean_abs_diff),
        (" transition_mean_grad_excess : ", metrics.transition_mean_grad_excess),
        (" boundary_ring_mean_abs_diff : ", metrics.boundary_ring_mean_abs_diff),
        (
            " boundary_ring_mean_grad_excess: ",
            metrics.boundary_ring_mean_grad_excess,
        ),
        (" boundary_ring_p95_abs_diff : ", metrics.boundary_ring_p95_abs_diff),
        (" global_mean_grad_excess : ", metrics.global_mean_grad_excess),
    )
    print(title)
    for label, value in labelled_values:
        print(f"{label}{value:.6f}")
def write_csv(path: Path, rows: list[ImageBenchmark]) -> None:
    """Write one CSV row per image to ``path`` (UTF-8, overwriting).

    Columns are: the image identity/fraction fields, every
    ``EdgeArtifactMetrics`` field prefixed with ``gaussian_`` then ``nn_``,
    and the improvement ratios prefixed with ``improvement_``.
    """
    metric_names = list(EdgeArtifactMetrics.__dataclass_fields__)
    improvement_names = [
        "boundary_ring_mean_abs_diff",
        "boundary_ring_mean_grad_excess",
        "transition_mean_grad_excess",
        "in_focus_mean_grad_excess",
    ]
    # These column names double as ImageBenchmark attribute names.
    base_columns = [
        "image",
        "width",
        "height",
        "in_focus_fraction",
        "transition_fraction",
        "boundary_ring_fraction",
    ]

    header = list(base_columns)
    header += [
        f"{prefix}_{name}"
        for prefix in ("gaussian", "nn")
        for name in metric_names
    ]
    header += [f"improvement_{name}" for name in improvement_names]

    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        for row in rows:
            record = {column: getattr(row, column) for column in base_columns}
            sources = (("gaussian", row.gaussian), ("nn", row.nn_render))
            for prefix, metrics in sources:
                for name in metric_names:
                    record[f"{prefix}_{name}"] = getattr(metrics, name)
            for name in improvement_names:
                record[f"improvement_{name}"] = row.nn_vs_gaussian_improvement[
                    name
                ]
            writer.writerow(record)
def write_json(path: Path, rows: list[ImageBenchmark]) -> None:
    """Write the run configuration, per-image results and averages as JSON.

    The file is UTF-8 with two-space indentation. ``json.dumps`` is used with
    its defaults, so non-finite floats (e.g. an infinite improvement ratio)
    are emitted as ``Infinity``/``NaN`` tokens.
    """
    config = {
        "f_stop": F_STOP,
        "focal_length_mm": FOCAL_LENGTH_MM,
        "coc_max_px": COC_MAX_PX,
        "coc_focus_threshold_px": COC_FOCUS_THRESHOLD_PX,
        "gaussian_coc_threshold_px": GAUSSIAN_COC_THRESHOLD_PX,
        "gaussian_sigma_px": GAUSSIAN_SIGMA_PX,
        "transition_coc_band": [TRANSITION_COC_LOW, TRANSITION_COC_HIGH],
        "boundary_ring_radius_px": BOUNDARY_RING_RADIUS_PX,
    }
    per_image = [asdict(row) for row in rows]
    averages = {
        method: asdict(average_metrics(rows, method))
        for method in ("gaussian", "nn_render")
    }
    document = {"config": config, "images": per_image, "averages": averages}
    path.write_text(json.dumps(document, indent=2), encoding="utf-8")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options ``--max-side`` and ``--output-dir``."""
    default_output = REPO_ROOT / "quantitative-tests" / "results"

    cli = argparse.ArgumentParser(
        description="Benchmark edge artifacts for prototype.py rendering."
    )
    cli.add_argument(
        "--max-side",
        type=int,
        default=1024,
        help="Resize longest image side before inference (default: 1024).",
    )
    cli.add_argument(
        "--output-dir",
        type=Path,
        default=default_output,
        help="Directory for CSV/JSON benchmark outputs.",
    )
    return cli.parse_args()
def main() -> None:
    """Benchmark every cached JPG, print the metrics and write CSV/JSON reports.

    Exits via ``SystemExit`` when ``cache/`` contains no JPG files.
    """
    args = parse_args()

    image_paths = list_cache_jpgs()
    if not image_paths:
        raise SystemExit("No JPG files found in cache/.")

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")
    print(f"Benchmarking {len(image_paths)} images from cache/")

    renderer = load_renderer(device)
    # The device reported by the depth loader is not needed here.
    depth_model, _depth_device = load_depth_anything(
        encoder=DEPTH_ENCODER, checkpoint_path=str(DEPTH_CHECKPOINT)
    )

    results: list[ImageBenchmark] = []
    for image_path in image_paths:
        print(f"\nProcessing {image_path.name} ...")
        result = benchmark_image(
            image_path,
            renderer=renderer,
            depth_model=depth_model,
            device=device,
            max_side=args.max_side,
        )
        results.append(result)

        print_metric_block(" Gaussian vs original:", result.gaussian)
        print_metric_block(" NN render vs original:", result.nn_render)
        print(" NN improvement ratios (Gaussian / NN, higher is better):")
        for key, value in result.nn_vs_gaussian_improvement.items():
            print(f" {key}: {value:.3f}x")

    gaussian_average = average_metrics(results, "gaussian")
    nn_average = average_metrics(results, "nn_render")
    print("\n" + "=" * 72)
    print("Averages across cache JPGs")
    print_metric_block("Gaussian vs original:", gaussian_average)
    print_metric_block("NN render vs original:", nn_average)

    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / "edge_artifact_benchmark.csv"
    json_path = out_dir / "edge_artifact_benchmark.json"
    write_csv(csv_path, results)
    write_json(json_path, results)
    print(f"\nWrote {csv_path}")
    print(f"Wrote {json_path}")


if __name__ == "__main__":
    main()