# NOTE: extraction artifact removed (file-size/commit/line-number dump from the
# original listing: 6,180 bytes, commit 2c76547). Kept as a comment so the file parses.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from collections import defaultdict
import torch.nn.functional as F
import torch
from tqdm import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import Configurable
from evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
from evaluation.utils.utils import (
PerceptionPrediction,
pretty_print_perception_metrics,
visualize_batch,
)
class Evaluator(Configurable):
    """
    A class defining the DynamicStereo evaluator.

    Runs a (pre-configured) stereo model over a dataloader of stereo video
    sequences, collects per-batch evaluation metrics, and optionally writes
    visualisations of predictions to disk / a summary writer.

    Args:
        eps: Threshold for clamping predicted disparity before it is
            converted to depth (avoids division by values near zero).
    """

    # Minimum disparity value; predictions below this are clamped up to it
    # before the depth = scale / disparity conversion.
    eps = 1e-5

    def setup_visualization(self, cfg: DictConfig) -> None:
        """Read visualisation settings from ``cfg`` and derive output paths.

        Expects ``cfg.visualize_interval`` (int; <= 0 disables visualisation)
        and ``cfg.exp_dir`` (str). Sets ``self.visualize_dir`` only when
        visualisation is enabled.
        """
        # Visualization
        self.visualize_interval = cfg.visualize_interval
        self.exp_dir = cfg.exp_dir
        if self.visualize_interval > 0:
            self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")

    @torch.no_grad()
    def evaluate_sequence(
        self,
        sci_enc_L,
        sci_enc_R,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        is_real_data: bool = False,
        step=None,
        writer=None,
        train_mode=False,
        interp_shape=None,
        # (H, W) crop applied to masks/videos before metrics/visualisation.
        # Tuple default instead of a list: avoids the mutable-default pitfall
        # (the value is only indexed, never mutated, so this is equivalent).
        resolution=(480, 640),
    ):
        """Evaluate ``model`` over every sequence in ``test_dataloader``.

        Args:
            sci_enc_L, sci_enc_R: SCI encoders forwarded to
                ``model.forward_batch_test`` (used only when ``train_mode``).
            model: Stereo model; called as ``forward_batch_test(...)`` in
                train mode, otherwise as ``model(batch_dict)``.
            test_dataloader: Yields dicts with at least ``"img"``; synthetic
                data additionally provides ``"disp"``, ``"valid_disp"`` and
                optionally ``"mask"``; visualisation uses ``"viewpoint"`` and
                ``"metadata"``.
            is_real_data: If True, no ground-truth disparity is available;
                optionally resize frames to ``interp_shape`` instead.
            step, writer: Passed through to ``visualize_batch`` for logging.
            train_mode: Selects the ``forward_batch_test`` code path.
            interp_shape: (H, W) bilinear-resize target for real data.
            resolution: (H, W) used to crop masks / original video.

        Returns:
            List of ``(batch_eval_result, seq_length)`` tuples, one per batch.
        """
        # -- Modified by Chu King on 20th November 2025 for SCI Stereo:
        # -- model.eval() intentionally not called here.
        per_batch_eval_results = []
        if self.visualize_interval > 0:
            os.makedirs(self.visualize_dir, exist_ok=True)
        for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
            batch_dict = defaultdict(list)
            batch_dict["stereo_video"] = sequence["img"]
            if not is_real_data:
                # Synthetic data: ground-truth disparity and validity mask.
                batch_dict["disparity"] = sequence["disp"][:, 0].abs()
                batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1]  # ~ (T, 1, 720, 1280)
                if "mask" in sequence:
                    batch_dict["fg_mask"] = sequence["mask"][:, :1]
                else:
                    # No foreground mask provided: treat every pixel as foreground.
                    batch_dict["fg_mask"] = torch.ones_like(
                        batch_dict["disparity_mask"]
                    )
            elif interp_shape is not None:
                # Real data: optionally resize both views to a common shape.
                left_video = batch_dict["stereo_video"][:, 0]
                left_video = F.interpolate(
                    left_video, tuple(interp_shape), mode="bilinear"
                )
                right_video = batch_dict["stereo_video"][:, 1]
                right_video = F.interpolate(
                    right_video, tuple(interp_shape), mode="bilinear"
                )
                batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1)
            # -- This method is always invoked with train_mode=True in practice.
            if train_mode:
                # -- Modified by Chu King on 20th November 2025:
                # -- forward_batch_test now also receives the SCI encoders.
                predictions = model.forward_batch_test(batch_dict, sci_enc_L, sci_enc_R)
            else:
                predictions = model(batch_dict)
            assert "disparity" in predictions
            # Keep only the first channel; detach to CPU for metric computation.
            predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu()
            if not is_real_data:
                # -- Modified by Chu King on 22nd November 2025:
                # Zero out predictions at invalid pixels; the mask is cropped to
                # `resolution` so it matches the prediction's spatial size.
                predictions["disparity"] = predictions["disparity"] * (
                    batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].round()
                )
            batch_eval_result, seq_length = eval_batch(batch_dict, predictions)
            per_batch_eval_results.append((batch_eval_result, seq_length))
            pretty_print_perception_metrics(batch_eval_result)
            if (self.visualize_interval > 0) and (
                batch_idx % self.visualize_interval == 0
            ):
                perception_prediction = PerceptionPrediction()
                pred_disp = predictions["disparity"]
                # Clamp tiny disparities so the depth conversion below is finite.
                pred_disp[pred_disp < self.eps] = self.eps
                scale = depth2disparity_scale(
                    sequence["viewpoint"][0][0],
                    sequence["viewpoint"][0][1],
                    torch.tensor([pred_disp.shape[2], pred_disp.shape[3]])[None],
                )
                perception_prediction.depth_map = (scale / pred_disp).cuda()
                perspective_cameras = []
                for cam in sequence["viewpoint"]:
                    perspective_cameras.append(cam[0])
                perception_prediction.perspective_cameras = perspective_cameras
                # -- Modified by Chu King on 22nd November 2025 to fix image
                # resolution during training: visualise the original (uncropped)
                # video, trimmed to `resolution`.
                if "stereo_original_video" in batch_dict:
                    batch_dict["stereo_video"] = batch_dict["stereo_original_video"][..., :resolution[0], :resolution[1]].clone()
                # Move all tensors to GPU for visualisation.
                for k, v in batch_dict.items():
                    if isinstance(v, torch.Tensor):
                        batch_dict[k] = v.cuda()
                visualize_batch(
                    batch_dict,
                    perception_prediction,
                    output_dir=self.visualize_dir,
                    sequence_name=sequence["metadata"][0][0][0],
                    step=step,
                    writer=writer,
                    # -- Added by Chu King on 22nd November 2025 to fix image
                    # resolution during evaluation.
                    resolution=resolution
                )
        return per_batch_eval_results