| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import warnings |
| |
|
| | import hydra |
| | import numpy as np |
| | import torch |
| | from nerf.dataset import get_nerf_datasets, trivial_collate |
| | from nerf.eval_video_utils import generate_eval_video_cameras |
| | from nerf.nerf_renderer import RadianceFieldRenderer |
| | from nerf.stats import Stats |
| | from omegaconf import DictConfig |
| | from PIL import Image |
| |
|
| |
|
# Absolute path of the Hydra "configs" directory that sits next to this
# script (symlinks in the script's own path are resolved first).
CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "configs",
)
| |
|
| |
|
def _select_device() -> str:
    """Return ``"cuda"`` when a GPU is available, otherwise warn and return ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    warnings.warn(
        "Please note that although executing on CPU is supported, "
        "the testing is unlikely to finish in reasonable time."
    )
    return "cpu"


def _build_model(cfg: "DictConfig"):
    """Instantiate a ``RadianceFieldRenderer`` from the ``data``, ``raysampler``
    and ``implicit_function`` config groups."""
    raysampler_cfg = cfg.raysampler
    implicit_cfg = cfg.implicit_function
    # Use the dedicated fine-pass sample count when the config provides one;
    # fall back to the coarse count for configs that lack the key.
    if "n_pts_per_ray_fine" in raysampler_cfg:
        n_pts_per_ray_fine = raysampler_cfg.n_pts_per_ray_fine
    else:
        n_pts_per_ray_fine = raysampler_cfg.n_pts_per_ray
    return RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=raysampler_cfg.n_pts_per_ray,
        n_pts_per_ray_fine=n_pts_per_ray_fine,
        n_rays_per_image=raysampler_cfg.n_rays_per_image,
        min_depth=raysampler_cfg.min_depth,
        max_depth=raysampler_cfg.max_depth,
        stratified=raysampler_cfg.stratified,
        stratified_test=raysampler_cfg.stratified_test,
        chunk_size_test=raysampler_cfg.chunk_size_test,
        n_harmonic_functions_xyz=implicit_cfg.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=implicit_cfg.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=implicit_cfg.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=implicit_cfg.n_hidden_neurons_dir,
        n_layers_xyz=implicit_cfg.n_layers_xyz,
        density_noise_std=implicit_cfg.density_noise_std,
    )


def _load_checkpoint(model, checkpoint_path: str, device: str) -> None:
    """Load the ``"model"`` entry of the checkpoint at *checkpoint_path* into *model*.

    Raises:
        ValueError: if *checkpoint_path* is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")

    print(f"Loading checkpoint {checkpoint_path}.")
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    # map_location remaps tensors saved on a GPU so the CPU fallback can load them.
    loaded_data = torch.load(checkpoint_path, map_location=device)
    # Entries for the ray sampler's cached xy grid are skipped, presumably because
    # the renderer rebuilds them from the configured image size (TODO confirm).
    state_dict = {
        k: v
        for k, v in loaded_data["model"].items()
        if "_grid_raysampler._xy_grid" not in k
    }
    model.load_state_dict(state_dict, strict=False)


def _build_test_dataset(cfg: "DictConfig", checkpoint_path: str):
    """Return ``(test_dataset, export_dir)`` for the configured ``cfg.test.mode``.

    ``export_dir`` is ``None`` in ``"evaluation"`` mode. In ``"export_video"``
    mode it is ``<checkpoint path without extension>_video`` and is created
    if it does not exist yet.

    Raises:
        ValueError: for an unknown ``cfg.test.mode``.
    """
    if cfg.test.mode == "evaluation":
        _, _, test_dataset = get_nerf_datasets(
            dataset_name=cfg.data.dataset_name,
            image_size=cfg.data.image_size,
        )
        return test_dataset, None

    if cfg.test.mode == "export_video":
        # Video cameras are generated from the training split along the
        # configured trajectory.
        train_dataset, _, _ = get_nerf_datasets(
            dataset_name=cfg.data.dataset_name,
            image_size=cfg.data.image_size,
        )
        test_dataset = generate_eval_video_cameras(
            train_dataset,
            trajectory_type=cfg.test.trajectory_type,
            up=cfg.test.up,
            scene_center=cfg.test.scene_center,
            n_eval_cams=cfg.test.n_frames,
            trajectory_scale=cfg.test.trajectory_scale,
        )
        export_dir = os.path.splitext(checkpoint_path)[0] + "_video"
        os.makedirs(export_dir, exist_ok=True)
        return test_dataset, export_dir

    raise ValueError(f"Unknown test mode {cfg.test.mode}.")


def _encode_video(export_dir: str, fps) -> str:
    """Encode the ``frame_%05d.png`` files in *export_dir* into ``video.mp4`` with ffmpeg.

    Returns:
        The path of the written video.

    Raises:
        RuntimeError: if ffmpeg cannot be launched or exits with a non-zero status.
    """
    import subprocess

    video_path = os.path.join(export_dir, "video.mp4")
    frame_regexp = os.path.join(export_dir, "frame_%05d.png")
    # An argument list (no shell) passes paths containing spaces or shell
    # metacharacters to ffmpeg verbatim.
    ffmpeg_cmd = [
        "ffmpeg",
        "-r", str(int(fps)),
        "-i", frame_regexp,
        "-vcodec", "h264",
        "-f", "mp4",
        "-y",
        "-b", "2000k",
        "-pix_fmt", "yuv420p",
        video_path,
    ]
    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        raise RuntimeError("ffmpeg failed!") from err
    return video_path


@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    """Evaluate a trained NeRF checkpoint or render a fly-through video from it.

    ``cfg.test.mode`` selects the behaviour:

    * ``"evaluation"``: render the test split and print the averaged
      ``mse``/``psnr`` (coarse and fine) and ``sec/it`` statistics.
    * ``"export_video"``: render cameras produced by
      ``generate_eval_video_cameras``, write one PNG per frame into
      ``<checkpoint>_video/`` and encode them into ``video.mp4`` with ffmpeg.

    Raises:
        ValueError: if the checkpoint file is missing or the mode is unknown.
        RuntimeError: if ffmpeg fails during video export.
    """
    device = _select_device()

    model = _build_model(cfg)
    model.to(device)

    # Hydra changes the working directory, so resolve the checkpoint path
    # against the directory the script was launched from.
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    _load_checkpoint(model, checkpoint_path, device)

    test_dataset, export_dir = _build_test_dataset(cfg, checkpoint_path)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    if cfg.test.mode == "evaluation":
        eval_stats = ["mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"]
        stats = Stats(eval_stats)
        stats.new_epoch()

    model.eval()
    for batch_idx, test_batch in enumerate(test_dataloader):
        # With batch_size=1, only the first sample is used. Its values are unpacked
        # positionally as (image, camera, camera index), relying on dict insertion order.
        test_image, test_camera, _ = test_batch[0].values()
        if test_image is not None:
            test_image = test_image.to(device)
        test_camera = test_camera.to(device)

        with torch.no_grad():
            test_nerf_out, test_metrics = model(None, test_camera, test_image)

        if cfg.test.mode == "evaluation":
            stats.update(test_metrics, stat_set="test")
            stats.print(stat_set="test")
        elif cfg.test.mode == "export_video":
            frame = test_nerf_out["rgb_fine"][0].detach().cpu().numpy()
            frame_path = os.path.join(export_dir, f"frame_{batch_idx:05d}.png")
            print(f"Writing {frame_path}.")
            # Clip before the uint8 cast so out-of-range values saturate
            # instead of wrapping around.
            pixels = (np.clip(frame, 0.0, 1.0) * 255.0).astype(np.uint8)
            Image.fromarray(pixels).save(frame_path)

    if cfg.test.mode == "evaluation":
        print(f"Final evaluation metrics on '{cfg.data.dataset_name}':")
        for stat in eval_stats:
            stat_value = stats.stats["test"][stat].get_epoch_averages()[0]
            print(f"{stat:15s}: {stat_value:1.4f}")
    elif cfg.test.mode == "export_video":
        _encode_video(export_dir, cfg.test.fps)
| |
|
| |
|
if __name__ == "__main__":
    # Hydra's decorator parses the command line and supplies ``cfg`` to main().
    main()
| |
|