Commit ·
e0700fc
1
Parent(s): 4ba91b5
Commit evaluate.py
Browse files- evaluation/evaluate.py +169 -126
evaluation/evaluate.py
CHANGED
|
@@ -4,140 +4,183 @@
|
|
| 4 |
# This source code is licensed under the license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
import
|
| 8 |
-
|
| 9 |
-
from dataclasses import dataclass, field
|
| 10 |
-
from typing import Any, Dict, Optional
|
| 11 |
-
|
| 12 |
-
import hydra
|
| 13 |
-
import numpy as np
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import torch
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
from dynamic_stereo.evaluation.utils.utils import aggregate_and_print_results
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
from
|
| 23 |
-
|
| 24 |
-
|
|
|
|
| 25 |
)
|
| 26 |
-
from
|
| 27 |
-
from
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
dstype=cfg.dstype,
|
| 91 |
-
sample_len=cfg.sample_len,
|
| 92 |
-
add_monkaa=False,
|
| 93 |
-
add_driving=False,
|
| 94 |
-
things_test=True,
|
| 95 |
-
)
|
| 96 |
-
elif cfg.dataset_name == "real":
|
| 97 |
-
for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]:
|
| 98 |
-
ds_path = f"./dynamic_replica_data/real/{real_sequence_name}"
|
| 99 |
-
# seq_len_real = 20
|
| 100 |
-
real_dataset = datasets.DynamicReplicaDataset(
|
| 101 |
-
split="test",
|
| 102 |
-
sample_len=cfg.sample_len,
|
| 103 |
-
root=ds_path,
|
| 104 |
-
only_first_n_samples=1,
|
| 105 |
)
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
return
|
| 114 |
-
|
| 115 |
-
print()
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
aggreegate_result = aggregate_and_print_results(evaluate_result)
|
| 123 |
-
|
| 124 |
-
result_file = os.path.join(cfg.exp_dir, f"result_eval.json")
|
| 125 |
-
|
| 126 |
-
print(f"Dumping eval results to {result_file}.")
|
| 127 |
-
with open(result_file, "w") as f:
|
| 128 |
-
json.dump(aggreegate_result, f)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
cs = hydra.core.config_store.ConfigStore.instance()
|
| 132 |
-
cs.store(name="default_config_eval", node=DefaultConfig)
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
| 136 |
-
def evaluate(cfg: DefaultConfig) -> None:
|
| 137 |
-
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 138 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
| 139 |
-
run_eval(cfg)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
if __name__ == "__main__":
|
| 143 |
-
evaluate()
|
|
|
|
| 4 |
# This source code is licensed under the license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.append("../")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
import argparse
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import os
|
| 15 |
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.optim as optim
|
|
|
|
| 18 |
|
| 19 |
+
from munch import DefaultMunch
|
| 20 |
+
import json
|
| 21 |
+
from pytorch_lightning.lite import LightningLite
|
| 22 |
+
from torch.cuda.amp import GradScaler
|
| 23 |
|
| 24 |
+
from train_utils.utils import (
|
| 25 |
+
run_test_eval,
|
| 26 |
+
save_ims_to_tb,
|
| 27 |
+
count_parameters,
|
| 28 |
)
|
| 29 |
+
from train_utils.logger import Logger
|
| 30 |
+
from models.core.dynamic_stereo import DynamicStereo
|
| 31 |
+
from models.core.sci_codec import sci_encoder
|
| 32 |
+
from evaluation.core.evaluator import Evaluator
|
| 33 |
+
from train_utils.losses import sequence_loss
|
| 34 |
+
import datasets.dynamic_stereo_datasets as datasets
|
| 35 |
+
|
| 36 |
+
class wrapper(nn.Module):
    """Couple two coded-exposure (SCI) encoders with a DynamicStereo network.

    Each stereo view's T-frame grayscale clip is compressed by its own
    ``sci_encoder`` into a multi-tap snapshot, and the two snapshots are fed
    to ``DynamicStereo`` to predict per-frame disparities.

    Parameters mirror the underlying ``sci_encoder`` / ``DynamicStereo``
    constructors; ``train_iters`` is the number of refinement iterations used
    in :meth:`forward`.
    """

    def __init__(
        self,
        sigma_range=None,
        num_frames=8,
        in_channels=1,
        n_taps=2,
        resolution=None,
        mixed_precision=True,
        attention_type="self_stereo_temporal_update_time_update_space",
        update_block_3d=True,
        different_update_blocks=True,
        train_iters=16,
    ):
        super().__init__()

        # The previous signature used mutable list defaults
        # (sigma_range=[0, 1e-9], resolution=[480, 640]); None sentinels keep
        # the same effective defaults without sharing one list object across
        # all instances.
        if sigma_range is None:
            sigma_range = [0, 1e-9]
        if resolution is None:
            resolution = [480, 640]

        self.train_iters = train_iters

        # One learned coded-exposure encoder per stereo view.
        self.sci_enc_L = sci_encoder(
            sigma_range=sigma_range,
            n_frame=num_frames,
            in_channels=in_channels,
            n_taps=n_taps,
            resolution=resolution,
        )
        self.sci_enc_R = sci_encoder(
            sigma_range=sigma_range,
            n_frame=num_frames,
            in_channels=in_channels,
            n_taps=n_taps,
            resolution=resolution,
        )

        self.stereo = DynamicStereo(
            max_disp=256,
            mixed_precision=mixed_precision,
            num_frames=num_frames,
            attention_type=attention_type,
            use_3d_update_block=update_block_3d,
            different_update_blocks=different_update_blocks,
        )

    def forward(self, batch):
        """Run the full snapshot-then-stereo pipeline on one batch.

        Args:
            batch: dict with ``"img"`` (assumed (B, T, 2, C, H, W) with
                left/right stacked on dim 2 — TODO confirm against the
                dataloader), ``"disp"`` and ``"valid_disp"`` ground truth.

        Returns:
            dict with a ``disp_{i}`` entry (loss + metrics) per view and a
            ``"disparity"`` entry holding detached final predictions.
        """

        def rgb_to_gray(x):
            # ITU-R BT.601 luma weights, applied over the channel axis.
            weights = torch.tensor(
                [0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device
            )
            return (x * weights[None, None, :, None, None]).sum(dim=2)  # (B, T, H, W)

        video_L = rgb_to_gray(batch["img"][:, :, 0])  # ~ (b, t, h, w)
        video_R = rgb_to_gray(batch["img"][:, :, 1])  # ~ (b, t, h, w)

        # Normalize from [0, 255] to [0, 1]; .contiguous() so later .view()
        # calls inside the encoders cannot fail on a non-contiguous tensor.
        video_L = (video_L / 255.0).contiguous()
        video_R = (video_R / 255.0).contiguous()

        # Coded exposure modeling: collapse each T-frame clip into an
        # n_taps-channel snapshot (c=2 for 2 taps).
        snapshot_L = self.sci_enc_L(video_L)  # ~ (b, c, h, w)
        snapshot_R = self.sci_enc_R(video_R)  # ~ (b, c, h, w)

        # Dynamic Stereo disparity estimation.
        disparities = self.stereo(
            snapshot_L,
            snapshot_R,
            iters=self.train_iters,
            test_mode=False,
        )

        output = {}
        n_views = len(batch["disp"][0])  # -- sample_len
        for i in range(n_views):
            seq_loss, metrics = sequence_loss(
                disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0]
            )
            # Divide by n_views so the per-view losses sum to a mean over views.
            output[f"disp_{i}"] = {"loss": seq_loss / n_views, "metrics": metrics}
        output["disparity"] = {
            "predictions": torch.cat(
                [disparities[-1, i, 0] for i in range(n_views)], dim=1
            ).detach(),
        }
        return output
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
if __name__ == "__main__":
    # Evaluation datasets: Sintel (clean/final passes) plus Dynamic Replica.
    dynamic_replica_ds = datasets.DynamicReplicaDataset(
        split="valid",
        sample_len=8,
        only_first_n_samples=1,
        VERBOSE=False,
        root="../dynamic_replica_data",
        t_step_validation=4,
    )
    dataloaders = [
        ("sintel_clean", datasets.SequenceSintelStereo(dstype="clean")),
        ("sintel_final", datasets.SequenceSintelStereo(dstype="final")),
        ("dynamic_replica", dynamic_replica_ds),
    ]

    # Evaluator with attribute-style visualization config (Munch wraps the
    # plain dict so it can be accessed as cfg.visualize_interval etc.).
    evaluator = Evaluator()
    vis_cfg = DefaultMunch.fromDict(
        {
            "visualize_interval": 1,  # Use 0 for no visualization
            "exp_dir": "./",
        },
        object(),
    )
    evaluator.setup_visualization(vis_cfg)

    # ----------------------------------------- Model Instantiation -----------------------------------------------
    model = wrapper(
        sigma_range=[0, 1e-9],
        num_frames=8,
        in_channels=1,
        n_taps=2,
        resolution=[480, 640],
        mixed_precision=True,
        attention_type="self_stereo_temporal_update_time_update_space",
        update_block_3d=True,
        different_update_blocks=True,
        train_iters=8,
    )

    # Restore pretrained weights on CPU. NOTE(review): torch.load unpickles
    # the checkpoint — only load files from a trusted source.
    ckpt_file = "../dynamicstereo_sf_dr/model_dynamic-stereo_050895.pth"
    checkpoint = torch.load(ckpt_file, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model"], strict=True)
    model.eval()

    # Hand the encoders and the stereo core to the evaluation driver.
    run_test_eval(
        ckpt_path="./",
        eval_type="valid",
        evaluator=evaluator,
        sci_enc_L=model.sci_enc_L,
        sci_enc_R=model.sci_enc_R,
        model=model.stereo,
        dataloaders=dataloaders,
        writer=None,
        step=None,
        resolution=[480, 640],
    )
|
| 185 |
+
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|