# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import defaultdict
import torch.nn.functional as F
import torch
from tqdm import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import Configurable

from evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
from evaluation.utils.utils import (
    PerceptionPrediction,
    pretty_print_perception_metrics,
    visualize_batch,
)


class Evaluator(Configurable):
    """
    A class defining the DynamicStereo evaluator.

    Args:
        eps: Threshold for converting disparity to depth. Predicted
            disparities below this value are raised to it before the
            ``scale / disparity`` division in the visualization path.
    """

    eps = 1e-5

    def setup_visualization(self, cfg: DictConfig) -> None:
        """Copy the visualization settings from ``cfg`` onto the evaluator.

        Reads ``cfg.visualize_interval`` and ``cfg.exp_dir``.
        ``self.visualize_dir`` is only defined when the interval is
        positive, i.e. when visualization is enabled.
        """
        self.visualize_interval = cfg.visualize_interval
        self.exp_dir = cfg.exp_dir
        if self.visualize_interval > 0:
            self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")

    @torch.no_grad()
    def evaluate_sequence(
        self,
        sci_enc_L,
        sci_enc_R,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        is_real_data: bool = False,
        step=None,
        writer=None,
        train_mode=False,
        interp_shape=None,
        resolution=None,
    ):
        """Run ``model`` over ``test_dataloader`` and collect per-batch metrics.

        ``setup_visualization`` must have been called beforehand, because
        this method reads ``self.visualize_interval`` (and
        ``self.visualize_dir`` when visualization is enabled).

        Args:
            sci_enc_L: Forwarded to ``model.forward_batch_test`` when
                ``train_mode`` is true; unused otherwise.
                NOTE(review): presumably the left-view SCI encoder - confirm.
            sci_enc_R: Same as ``sci_enc_L``, presumably for the right view.
            model: Object exposing ``forward_batch_test(batch, enc_L, enc_R)``
                (used when ``train_mode`` is true) or callable as
                ``model(batch)`` otherwise. The result must be a mapping
                containing a ``"disparity"`` tensor.
            test_dataloader: Yields sequence dicts. ``"img"`` is always read;
                ``"disp"``, ``"valid_disp"`` and optionally ``"mask"`` are
                read for non-real data; ``"viewpoint"`` and ``"metadata"``
                are read when a batch is visualized.
            is_real_data: When true there is no ground truth, so metric
                computation is skipped.
            step: Passed through to ``visualize_batch``.
            writer: Passed through to ``visualize_batch``.
            train_mode: Selects ``model.forward_batch_test`` over
                ``model.__call__``.
            interp_shape: Optional target size. Only used for real data,
                where both views are bilinearly resized to it.
            resolution: Two-element ``[height, width]`` used to crop the
                validity mask (and the original video for visualization)
                to the prediction size. Defaults to ``[480, 640]``.

        Returns:
            A list of ``(batch_eval_result, seq_length)`` tuples, one per
            batch, as returned by ``eval_batch``. The list is empty when
            ``is_real_data`` is true.
        """
        # A None sentinel replaces the former mutable list default, so each
        # call gets its own list and nothing downstream can alter the default.
        if resolution is None:
            resolution = [480, 640]

        # -- Modified by Chu King on 20th November 2025 for SCI Stereo.
        # -- model.eval()
        per_batch_eval_results = []

        if self.visualize_interval > 0:
            os.makedirs(self.visualize_dir, exist_ok=True)

        for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
            batch_dict = defaultdict(list)
            batch_dict["stereo_video"] = sequence["img"]

            if not is_real_data:
                batch_dict["disparity"] = sequence["disp"][:, 0].abs()
                # ~ (T, 1, 720, 1280)
                batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1]

                if "mask" in sequence:
                    batch_dict["fg_mask"] = sequence["mask"][:, :1]
                else:
                    # No foreground mask supplied: treat every pixel as
                    # foreground.
                    batch_dict["fg_mask"] = torch.ones_like(
                        batch_dict["disparity_mask"]
                    )
            elif interp_shape is not None:
                left_video = batch_dict["stereo_video"][:, 0]
                left_video = F.interpolate(
                    left_video, tuple(interp_shape), mode="bilinear"
                )
                right_video = batch_dict["stereo_video"][:, 1]
                right_video = F.interpolate(
                    right_video, tuple(interp_shape), mode="bilinear"
                )
                batch_dict["stereo_video"] = torch.stack(
                    [left_video, right_video], 1
                )

            # -- This method is always invoked with train_mode=True.
            if train_mode:
                # -- Modified by Chu King on 20th November 2025.
                # -- predictions = model.forward_batch_test(batch_dict)
                predictions = model.forward_batch_test(
                    batch_dict, sci_enc_L, sci_enc_R
                )
            else:
                predictions = model(batch_dict)

            assert "disparity" in predictions
            # Keep only the first channel and move it to the CPU.
            predictions["disparity"] = (
                predictions["disparity"][:, :1].clone().cpu()
            )

            if not is_real_data:
                # Zero out predictions at invalid ground-truth pixels.
                # -- Modified by Chu King on 22nd November 2025
                # -- batch_dict["disparity_mask"].round()
                # The mask is cropped to `resolution` before multiplying.
                valid_mask = batch_dict["disparity_mask"][
                    ..., : resolution[0], : resolution[1]
                ].round()
                predictions["disparity"] = predictions["disparity"] * valid_mask

                batch_eval_result, seq_length = eval_batch(batch_dict, predictions)

                per_batch_eval_results.append((batch_eval_result, seq_length))
                pretty_print_perception_metrics(batch_eval_result)

            if (self.visualize_interval > 0) and (
                batch_idx % self.visualize_interval == 0
            ):
                perception_prediction = PerceptionPrediction()

                pred_disp = predictions["disparity"]
                # Raise tiny disparities to eps (in place) so the division
                # below cannot divide by zero.
                pred_disp[pred_disp < self.eps] = self.eps

                scale = depth2disparity_scale(
                    sequence["viewpoint"][0][0],
                    sequence["viewpoint"][0][1],
                    torch.tensor([pred_disp.shape[2], pred_disp.shape[3]])[None],
                )

                # NOTE: the visualization path moves tensors with .cuda(),
                # so it needs a CUDA device.
                perception_prediction.depth_map = (scale / pred_disp).cuda()
                perception_prediction.perspective_cameras = [
                    cam[0] for cam in sequence["viewpoint"]
                ]

                # -- Modified by Chu King on 22nd November 2025 to fix image
                # -- resolution during training.
                if "stereo_original_video" in batch_dict:
                    batch_dict["stereo_video"] = batch_dict[
                        "stereo_original_video"
                    ][..., : resolution[0], : resolution[1]].clone()

                # Only existing keys are reassigned, so the dict size does
                # not change during iteration.
                for k, v in batch_dict.items():
                    if isinstance(v, torch.Tensor):
                        batch_dict[k] = v.cuda()

                visualize_batch(
                    batch_dict,
                    perception_prediction,
                    output_dir=self.visualize_dir,
                    sequence_name=sequence["metadata"][0][0][0],
                    step=step,
                    writer=writer,
                    # -- Added by Chu King on 22nd November 2025 to fix image
                    # -- resolution during evaluation.
                    resolution=resolution,
                )

        return per_batch_eval_results