File size: 6,180 Bytes
2c76547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import defaultdict
import torch.nn.functional as F
import torch
from tqdm import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import Configurable

from evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
from evaluation.utils.utils import (
    PerceptionPrediction,
    pretty_print_perception_metrics,
    visualize_batch,
)


class Evaluator(Configurable):
    """
    A class defining the DynamicStereo evaluator.

    ``setup_visualization`` must be called before ``evaluate_sequence``: the
    latter reads ``self.visualize_interval`` (and ``self.visualize_dir`` when
    visualization is enabled), which are only set by ``setup_visualization``.

    Args:
        eps: Lower bound applied to predicted disparity before it is converted
            to depth (``scale / disparity``) for visualization, so the
            conversion never divides by zero.
    """

    eps = 1e-5

    def setup_visualization(self, cfg: DictConfig) -> None:
        """Read visualization settings from ``cfg``.

        Args:
            cfg: Config providing ``visualize_interval`` and ``exp_dir``.
                When ``visualize_interval > 0``, visualizations are written to
                ``<exp_dir>/visualisations``.
        """
        self.visualize_interval = cfg.visualize_interval
        self.exp_dir = cfg.exp_dir
        if self.visualize_interval > 0:
            self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")

    @staticmethod
    def _build_batch_dict(sequence, is_real_data, interp_shape):
        """Assemble the model-input / ground-truth dict for one dataloader item.

        For synthetic data (``is_real_data=False``) the GT disparity, its
        validity mask and a foreground mask are attached. For real data, the
        stereo video is optionally resized to ``interp_shape``.
        """
        batch_dict = defaultdict(list)
        batch_dict["stereo_video"] = sequence["img"]

        if not is_real_data:
            batch_dict["disparity"] = sequence["disp"][:, 0].abs()
            # Keep only the first channel of the validity mask
            # (author's note: shape ~ (T, 1, 720, 1280)).
            batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1]

            if "mask" in sequence:
                batch_dict["fg_mask"] = sequence["mask"][:, :1]
            else:
                # No foreground mask supplied: treat every pixel as foreground.
                batch_dict["fg_mask"] = torch.ones_like(batch_dict["disparity_mask"])
        elif interp_shape is not None:
            # Only real (GT-free) data is resized; left and right views are
            # resized independently and re-stacked along dim 1.
            size = tuple(interp_shape)
            stereo_video = batch_dict["stereo_video"]
            left_video = F.interpolate(stereo_video[:, 0], size, mode="bilinear")
            right_video = F.interpolate(stereo_video[:, 1], size, mode="bilinear")
            batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1)

        return batch_dict

    def _visualize(self, sequence, batch_dict, predictions, step, writer, resolution):
        """Convert predicted disparity to depth and hand it to ``visualize_batch``.

        Side effects (kept from the original inline implementation):
        ``predictions["disparity"]`` is clamped to ``self.eps`` in place and
        every tensor in ``batch_dict`` is replaced by its CUDA copy.
        """
        perception_prediction = PerceptionPrediction()

        pred_disp = predictions["disparity"]
        # Clamp in place so the depth conversion below never divides by zero.
        pred_disp[pred_disp < self.eps] = self.eps

        scale = depth2disparity_scale(
            sequence["viewpoint"][0][0],
            sequence["viewpoint"][0][1],
            torch.tensor([pred_disp.shape[2], pred_disp.shape[3]])[None],
        )
        perception_prediction.depth_map = (scale / pred_disp).cuda()
        perception_prediction.perspective_cameras = [
            cam[0] for cam in sequence["viewpoint"]
        ]

        # If the model stored the uncropped input video in batch_dict, visualize
        # it cropped to the evaluated resolution instead of the model input.
        if "stereo_original_video" in batch_dict:
            batch_dict["stereo_video"] = batch_dict["stereo_original_video"][
                ..., : resolution[0], : resolution[1]
            ].clone()

        for key, value in batch_dict.items():
            if isinstance(value, torch.Tensor):
                batch_dict[key] = value.cuda()

        visualize_batch(
            batch_dict,
            perception_prediction,
            output_dir=self.visualize_dir,
            sequence_name=sequence["metadata"][0][0][0],
            step=step,
            writer=writer,
            resolution=resolution,
        )

    @torch.no_grad()
    def evaluate_sequence(
        self,
        sci_enc_L,
        sci_enc_R,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        is_real_data: bool = False,
        step=None,
        writer=None,
        train_mode=False,
        interp_shape=None,
        resolution=None,
    ):
        """Run ``model`` over ``test_dataloader`` and collect per-batch metrics.

        Args:
            sci_enc_L, sci_enc_R: Forwarded to ``model.forward_batch_test``
                when ``train_mode`` is True.
            model: Stereo model. Called as ``model.forward_batch_test(...)`` in
                train mode, otherwise as ``model(batch_dict)``; the result must
                contain a ``"disparity"`` entry.
            test_dataloader: Yields dicts with ``"img"`` and, for synthetic
                data, ``"disp"``, ``"valid_disp"`` and optionally ``"mask"``.
                ``"viewpoint"`` and ``"metadata"`` are read when visualizing.
            is_real_data: When True no GT metrics are computed and the video
                may be resized to ``interp_shape``.
            step, writer: Forwarded to ``visualize_batch``.
            train_mode: Selects the ``forward_batch_test`` path.
            interp_shape: Optional target size for real-data inputs.
            resolution: Crop ``[rows, cols]`` applied to the last two dims of
                the GT mask and the visualized video. Defaults to
                ``[480, 640]``.

        Returns:
            A list of ``(batch_eval_result, seq_length)`` tuples from
            ``eval_batch``, one per batch; empty when ``is_real_data`` is True.
        """
        # None sentinel instead of a shared mutable list default.
        if resolution is None:
            resolution = [480, 640]
        rows, cols = resolution[0], resolution[1]

        # NOTE: model.eval() was deliberately removed for SCI Stereo
        # (20 Nov 2025); presumably the caller controls the model's mode.

        per_batch_eval_results = []

        visualize = self.visualize_interval > 0
        if visualize:
            os.makedirs(self.visualize_dir, exist_ok=True)

        for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
            batch_dict = self._build_batch_dict(sequence, is_real_data, interp_shape)

            # Per the original author, this method is always invoked with
            # train_mode=True.
            if train_mode:
                predictions = model.forward_batch_test(batch_dict, sci_enc_L, sci_enc_R)
            else:
                predictions = model(batch_dict)

            assert "disparity" in predictions
            predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu()

            if not is_real_data:
                # Zero out predictions outside the valid GT region; the mask is
                # cropped to the evaluated resolution first.
                valid = batch_dict["disparity_mask"][..., :rows, :cols].round()
                predictions["disparity"] = predictions["disparity"] * valid

                batch_eval_result, seq_length = eval_batch(batch_dict, predictions)
                per_batch_eval_results.append((batch_eval_result, seq_length))
                pretty_print_perception_metrics(batch_eval_result)

            if visualize and batch_idx % self.visualize_interval == 0:
                self._visualize(sequence, batch_dict, predictions, step, writer, resolution)

        return per_batch_eval_results