depthsplat / src /evaluation /metric_computer.py
Yeqing0814's picture
Upload folder using huggingface_hub
a6dd040 verified
import os
import subprocess
from pathlib import Path

import torch
from pytorch_lightning import LightningModule
from tabulate import tabulate

from ..misc.image_io import load_image, save_image
from ..visualization.annotation import add_label
from ..visualization.layout import add_border, hcat
from .evaluation_cfg import EvaluationCfg
from .metrics import compute_lpips, compute_psnr, compute_ssim
class MetricComputer(LightningModule):
    """Score pre-rendered images from several methods against ground truth.

    For each test batch the images stored under every configured method's
    ``path / scene / color`` folder are loaded, compared with the batch's
    target images using LPIPS, SSIM and PSNR, logged, and summarised as a
    running-average table. When ``cfg.side_by_side_path`` is set, labelled
    side-by-side comparison frames are written as well, optionally followed
    by an ffmpeg-encoded video of those frames.
    """

    cfg: EvaluationCfg

    def __init__(self, cfg: EvaluationCfg) -> None:
        """Store the evaluation config and reset the running-metric state."""
        super().__init__()
        self.cfg = cfg
        # Running averages maintained by print_preview_metrics; ``None`` until
        # the first scene has been scored.
        self.running_metrics = None
        self.running_metric_steps = 0

    def test_step(self, batch, batch_idx):
        """Compute metrics for one scene and optionally save comparisons.

        Args:
            batch: Mapping from which ``batch["scene"]``,
                ``batch["context"]["image"]``, ``batch["target"]["image"]``
                and ``batch["target"]["index"]`` are read. The first two
                dimensions of the context image tensor must be 1 and 2
                (presumably batch size and number of context views).
            batch_idx: Index of the batch; used as a zero-padded prefix of
                the side-by-side output folder name.

        Returns:
            None. The scene is skipped (with a printed notice) when any
            method lacks a folder for it or an expected image file is missing.
        """
        scene = batch["scene"][0]
        b, cv, _, _, _ = batch["context"]["image"].shape
        assert b == 1 and cv == 2, (
            f"expected context image shape to start with (1, 2), got ({b}, {cv})"
        )
        _, v, _, _, _ = batch["target"]["image"].shape

        # Skip scenes that at least one method has no output folder for.
        if any(not (method.path / scene).exists() for method in self.cfg.methods):
            print(f'Skipping "{scene}".')
            return

        # Load the images.
        target_indices = batch["target"]["index"][0]
        all_images = {}
        try:
            for method in self.cfg.methods:
                images = [
                    load_image(method.path / scene / f"color/{index.item():0>6}.png")
                    for index in target_indices
                ]
                all_images[method.key] = torch.stack(images).to(self.device)
        except FileNotFoundError:
            print(f'Skipping "{scene}".')
            return

        # Compute metrics.
        rgb_gt = batch["target"]["image"][0]
        all_metrics = {}
        for key, images in all_images.items():
            all_metrics[f"lpips_{key}"] = compute_lpips(rgb_gt, images).mean()
            all_metrics[f"ssim_{key}"] = compute_ssim(rgb_gt, images).mean()
            all_metrics[f"psnr_{key}"] = compute_psnr(rgb_gt, images).mean()
        self.log_dict(all_metrics)
        self.print_preview_metrics(all_metrics)

        # Skip the rest if no side-by-side is needed.
        if self.cfg.side_by_side_path is None:
            return

        # Create side-by-side.
        scene_key = f"{batch_idx:0>6}_{scene}"
        frame_ids = [index.item() for index in target_indices]
        # The caption is identical for every frame of the scene, so build it
        # once; the guard keeps a scene with zero target frames from raising.
        label = (
            f"Scene {scene} (frames {frame_ids[0]} to {frame_ids[-1]})"
            if frame_ids
            else ""
        )
        for i in range(v):
            row = [add_label(batch["target"]["image"][0, i], "Ground Truth")]
            for method in self.cfg.methods:
                row.append(add_label(all_images[method.key][i], method.name))
            row = add_border(add_label(hcat(*row), label, font_size=16))
            save_image(
                row,
                self.cfg.side_by_side_path / scene_key / f"{frame_ids[i]:0>6}.png",
            )

        # Create an animation.
        if self.cfg.animate_side_by_side:
            self._encode_side_by_side_video(scene_key)

    def _encode_side_by_side_video(self, scene_key: str) -> None:
        """Encode the saved side-by-side frames of one scene into an MP4.

        ffmpeg is run inside the scene's frame folder and writes
        ``<side_by_side_path>/videos/<scene_key>.mp4``. This step is
        best-effort: a missing ffmpeg binary, a missing frame folder or a
        non-zero exit status is reported with a printed message, not raised.
        """
        frame_dir = self.cfg.side_by_side_path / scene_key
        video_dir = self.cfg.side_by_side_path / "videos"
        video_dir.mkdir(exist_ok=True, parents=True)

        # ffmpeg runs with a different working directory, so the output must
        # be absolute. Joining onto Path.cwd() leaves an already-absolute
        # side_by_side_path untouched.
        video_path = Path.cwd() / video_dir / f"{scene_key}.mp4"

        # An argument list (no shell) keeps scene names and paths containing
        # spaces or shell metacharacters from being re-interpreted. The
        # "*.png" pattern is expanded by ffmpeg itself via -pattern_type glob.
        command = [
            "ffmpeg",
            "-y",
            "-framerate",
            "30",
            "-pattern_type",
            "glob",
            "-i",
            "*.png",
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-vf",
            "pad=ceil(iw/2)*2:ceil(ih/2)*2",
            str(video_path),
        ]
        try:
            result = subprocess.run(command, cwd=frame_dir, check=False)
        except OSError as error:
            print(f'Could not run ffmpeg for "{scene_key}": {error}')
            return
        if result.returncode != 0:
            print(f'ffmpeg exited with code {result.returncode} for "{scene_key}".')

    def print_preview_metrics(self, metrics: dict[str, float]) -> None:
        """Update the running metric averages and print them as a table.

        Args:
            metrics: Values keyed ``"<metric>_<method.key>"`` for the metrics
                ``psnr``, ``lpips`` and ``ssim``. Anything accepted by
                ``float()`` works, including zero-dimensional tensors.
        """
        # Plain floats match the annotation and avoid keeping tensors alive
        # between steps.
        metrics = {name: float(value) for name, value in metrics.items()}

        if getattr(self, "running_metrics", None) is None:
            self.running_metrics = metrics
            self.running_metric_steps = 1
        else:
            # Incremental mean over the scenes scored so far.
            s = self.running_metric_steps
            self.running_metrics = {
                k: ((s * v) + metrics[k]) / (s + 1)
                for k, v in self.running_metrics.items()
            }
            self.running_metric_steps += 1

        table = []
        for method in self.cfg.methods:
            row = [
                f"{self.running_metrics[f'{metric}_{method.key}']:.3f}"
                for metric in ("psnr", "lpips", "ssim")
            ]
            table.append((method.key, *row))

        table = tabulate(table, ["Method", "PSNR (dB)", "LPIPS", "SSIM"])
        print(table)