phiph's picture
Upload folder using huggingface_hub
7382c66 verified
# Authors: Jing He, Haodong Li
# Last modified: 2025-10-10
import os
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from tabulate import tabulate
from eval.datasets import (
BaseDepthDataset,
get_dataset,
DatasetMode,
get_pred_name
)
from eval.utils import (
MetricTracker,
metric,
init_per_sample_csv,
write_per_sample_csv,
write_metrics_txt,
align_depth_least_square,
align_depth_median
)
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image so that its longer edge equals `max_edge_resolution`,
    keeping the aspect ratio.

    The scale factor is not capped at 1, so an image whose longer edge is
    shorter than `max_edge_resolution` is upscaled.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Target length of the longer edge (pixel).
        resample_method (`torchvision.transforms.InterpolationMode`):
            Interpolation mode used to resize images.
    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"
    original_height, original_width = img.shape[-2:]
    long_edge = max(original_height, original_width)

    # Use integer arithmetic instead of `int(edge * (max_res / edge))`:
    # the float round trip can land just below the intended integer and
    # truncate to one pixel less than requested.
    new_height = int(original_height * max_edge_resolution // long_edge)
    new_width = int(original_width * max_edge_resolution // long_edge)

    if max_edge_resolution >= 1:
        # Very elongated inputs would otherwise truncate the short edge to 0,
        # which is not a valid output size.
        new_height = max(new_height, 1)
        new_width = max(new_width, 1)

    resized_img = resize(
        img, (new_height, new_width), resample_method, antialias=True
    )
    return resized_img
def run_evaluation(model, config, dataset_name, output_dir, device):
    """
    Evaluate `model` on one dataset and record per-sample and aggregate metrics.

    For every sample the RGB image is resized (longer edge 1092 px), scaled to
    [0, 1] by dividing by 255, cast to the configured dtype and passed to
    `model`. The prediction is resized back to the input resolution, optionally
    aligned to the ground truth, clipped to the dataset depth range, and scored
    with every configured metric.

    Args:
        model: Callable invoked as `model(rgb)` with a 4-D tensor (batch of 1).
            Its output is unsqueezed once and resized to the input H x W, so it
            is presumably a [1, H', W'] depth map -- confirm against the model.
        config: Nested mapping. Keys read here:
            `config['evaluation']['datasets'][dataset_name]`,
            `config['evaluation']['datasets_dir']`,
            `config['evaluation']['metric_names']`,
            `config['evaluation']['alignment']` and
            `config['spherevit']['dtype']`.
        dataset_name: Key of the dataset inside the evaluation config; also the
            name of the sub-directory created under `output_dir`.
        output_dir: Directory under which `<output_dir>/<dataset_name>` is
            created for the evaluation outputs.
        device: Device on which inference and metric computation run.

    Returns:
        The value of `metric_tracker.result()` after all samples were processed.

    Raises:
        NotImplementedError: If the configured alignment is not one of
            "least_square", "median" or "metric".
    """
    eval_dir = os.path.join(output_dir, dataset_name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(eval_dir, exist_ok=True)

    dataset_config = config['evaluation']['datasets'][dataset_name]
    model_dtype = config['spherevit']['dtype']

    dataset: BaseDepthDataset = get_dataset(
        dataset_config,
        dataset_name,
        base_data_dir=config['evaluation']['datasets_dir'],
        mode=DatasetMode.EVAL,
    )
    # batch_size must stay 1: the loop below indexes sample 0 and squeezes
    # the batch dimension away.
    dataloader = DataLoader(dataset, batch_size=1, num_workers=16, pin_memory=True)

    metric_funcs = [
        getattr(metric, _met) for _met in config['evaluation']['metric_names']
    ]
    metric_tracker = MetricTracker(*[m.__name__ for m in metric_funcs])
    metric_tracker.reset()

    alignment = config['evaluation']['alignment']
    per_sample_csv = init_per_sample_csv(eval_dir, alignment, metric_funcs)

    for data in tqdm(dataloader, desc=f"Evaluating {dataset_name}"):
        # ---------------------------- GT data ----------------------------
        depth_raw_ts = data["depth_raw_linear"].squeeze()
        valid_mask_ts = data["valid_mask_raw"].squeeze()
        rgb_name = data["rgb_relative_path"][0]

        depth_raw = depth_raw_ts.numpy()
        valid_mask = valid_mask_ts.numpy()

        rgb_tmp = data["rgb_int"].squeeze()
        if rgb_tmp.min() == 0:
            # Exclude pixels whose mean over dim 0 (presumably the channel
            # axis) is exactly zero, i.e. fully black pixels.
            rgb_tmp = rgb_tmp.float()
            rgb_tmp = torch.mean(rgb_tmp, dim=0).cpu().numpy()
            nonblack_mask = rgb_tmp != 0
            valid_mask = valid_mask & nonblack_mask
            # NOTE(review): only the NumPy mask used for alignment is
            # restricted here; `valid_mask_ts`, used for the metrics below,
            # is not. Confirm whether that asymmetry is intended.

        depth_raw_ts = depth_raw_ts.to(device)
        valid_mask_ts = valid_mask_ts.to(device)

        # ------------------------ Prediction name ------------------------
        rgb_basename = os.path.basename(rgb_name)
        pred_basename = get_pred_name(rgb_basename, dataset.name_mode, suffix="")
        pred_name = os.path.join(os.path.dirname(rgb_name), pred_basename)

        # --------------------------- Inference ---------------------------
        # Resize to the processing resolution (hard-coded to 1092 px).
        input_size = data["rgb_int"].shape
        input_rgb = resize_max_res(
            data["rgb_int"],
            max_edge_resolution=1092,
        )
        input_rgb = input_rgb[0].to(device)
        input_rgb = input_rgb / 255.0
        input_rgb = input_rgb.to(model_dtype)

        # Evaluation needs no gradients; without no_grad the autograd graph is
        # kept alive and `.numpy()` fails on outputs that require grad.
        with torch.no_grad():
            depth_pred = model(input_rgb.unsqueeze(0))
            depth_pred = depth_pred.unsqueeze(0)
            # Back to the original input resolution.
            depth_pred = resize(depth_pred, input_size[-2:], antialias=True)

        if depth_pred.dtype in (torch.float16, torch.bfloat16):
            # bfloat16 cannot be converted to NumPy, and float16 cannot
            # represent the +-1e6 clip bounds used below.
            depth_pred = depth_pred.float()
        depth_pred = depth_pred.squeeze().cpu().numpy()

        # --------------------------- Alignment ---------------------------
        if "least_square" == alignment:
            depth_pred = np.clip(depth_pred, a_min=-1e6, a_max=1e6)
            depth_pred, _, _ = align_depth_least_square(
                gt_arr=depth_raw,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
                max_resolution=dataset_config['alignment_max_res'],
            )
        elif "median" == alignment:
            depth_pred = np.clip(depth_pred, a_min=-1e6, a_max=1e6)
            depth_pred, _, _ = align_depth_median(
                gt_arr=depth_raw,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
            )
        elif "metric" == alignment:
            # Prediction is used as-is, without alignment to the ground truth.
            pass
        else:
            raise NotImplementedError(
                f"Unsupported alignment mode: {alignment!r}"
            )

        # Clip to dataset min max
        depth_pred = np.clip(
            depth_pred, a_min=dataset.min_depth, a_max=dataset.max_depth
        )
        # clip to d > 0 for evaluation
        depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

        # ------------- Evaluate (on `device`, CUDA if given) -------------
        sample_metric = []
        depth_pred_ts = torch.from_numpy(depth_pred).to(device)
        for met_func in metric_funcs:
            _metric_name = met_func.__name__
            _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
            sample_metric.append(str(_metric))
            metric_tracker.update(_metric_name, _metric)
        write_per_sample_csv(per_sample_csv, pred_name, sample_metric)

    # -------------------- Save metrics to file --------------------
    results = metric_tracker.result()
    eval_text = tabulate([results.keys(), results.values()])
    write_metrics_txt(eval_dir, alignment, eval_text)
    return results