| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import collections |
| | import os |
| | import pickle |
| | import warnings |
| |
|
| | import hydra |
| | import numpy as np |
| | import torch |
| | from nerf.dataset import get_nerf_datasets, trivial_collate |
| | from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs |
| | from nerf.stats import Stats |
| | from omegaconf import DictConfig |
| | from visdom import Visdom |
| |
|
| |
|
# Absolute path (symlinks resolved) of the "configs" directory located next to
# this script; passed to @hydra.main as its config_path.
CONFIG_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], "configs")
| |
|
| |
|
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    """Train a NeRF ``RadianceFieldRenderer`` as described by ``cfg``.

    Hydra builds ``cfg`` from the ``lego`` config in ``CONFIG_DIR`` plus any
    command-line overrides. The function seeds numpy and torch, builds the
    model, optionally resumes model/optimizer/stats from a checkpoint, and
    then trains with Adam and an exponentially decaying learning rate.
    Every ``cfg.validation_epoch_interval`` epochs it evaluates one random
    validation view (and, if Visdom is enabled, updates the plots); every
    ``cfg.checkpoint_epoch_interval`` epochs it saves a checkpoint.

    Args:
        cfg: Hydra/OmegaConf configuration. Sections read here are ``data``,
            ``raysampler``, ``implicit_function``, ``optimizer`` and
            ``visualization``, plus the top-level keys ``seed``, ``resume``,
            ``checkpoint_path``, ``stats_print_interval``,
            ``validation_epoch_interval`` and ``checkpoint_epoch_interval``.
            An empty ``checkpoint_path`` disables both resuming and saving.
    """
    # Seed the numpy and torch RNGs for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Prefer the GPU; CPU execution is supported but very slow.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # Initialize the radiance-field renderer.
    # NOTE(review): the fine pass reuses ``raysampler.n_pts_per_ray``; confirm
    # whether a separate fine-pass sample count was intended.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
        visualization=cfg.visualization.visdom,
    )
    model.to(device)

    # Training state; overwritten below when resuming from a checkpoint.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    use_checkpoint = len(cfg.checkpoint_path) > 0
    # Resolve against the launch directory rather than the current working
    # directory, which Hydra may have switched to a per-run output directory.
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if use_checkpoint:
        # Make sure the checkpoint's parent directory exists.
        os.makedirs(os.path.split(checkpoint_path)[0], exist_ok=True)

        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            # map_location lets a checkpoint written on a GPU be resumed on a
            # CPU-only machine (and vice versa).
            loaded_data = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(loaded_data["model"])
            # SECURITY: unpickling can execute arbitrary code; only resume
            # from checkpoints you produced yourself.
            stats = pickle.loads(loaded_data["stats"])
            print(f"   => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Exponential decay: the LR is multiplied by ``lr_scheduler_gamma`` every
    # ``lr_scheduler_step_size`` epochs, applied smoothly at each epoch.
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # ``last_epoch=start_epoch - 1`` fast-forwards the schedule when resuming;
    # PyTorch then requires the ``initial_lr`` entries that were restored with
    # the optimizer state above.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1
    )

    # Bounded history of recent training outputs, handed to
    # visualize_nerf_outputs at validation time.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Have the model precache rays for every train/val camera, keyed by
        # each sample's camera_idx.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # Validation draws single random views with replacement; a fresh iterator
    # is created at every validation step.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

    model.train()
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()
        for iteration, batch in enumerate(train_dataloader):
            # batch_size=1: take the single sample and read it by key instead
            # of relying on the dict's insertion order.
            sample = batch[0]
            image = sample["image"].to(device)
            camera = sample["camera"].to(device)
            camera_idx = sample["camera_idx"]

            optimizer.zero_grad()

            # camera_idx is passed only when rays were precached (it is the key
            # used in precache_rays above).
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # Supervise both the coarse and the fine rendering pass.
            loss = metrics["mse_coarse"] + metrics["mse_fine"]
            loss.backward()
            optimizer.step()

            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )
            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # The LR schedule advances once per epoch.
        lr_scheduler.step()

        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
            val_sample = next(iter(val_dataloader))[0]
            val_image = val_sample["image"].to(device)
            val_camera = val_sample["camera"].to(device)
            val_camera_idx = val_sample["camera_idx"]

            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    val_camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Back to training mode for the next epoch.
            model.train()

        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and use_checkpoint
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            # Write to a sibling temp file, then atomically swap it in, so an
            # interruption mid-save cannot corrupt the previous checkpoint.
            tmp_checkpoint_path = checkpoint_path + ".tmp"
            torch.save(data_to_store, tmp_checkpoint_path)
            os.replace(tmp_checkpoint_path, checkpoint_path)
| |
|
| |
|
if __name__ == "__main__":
    # The @hydra.main decorator supplies ``cfg``, so main() takes no arguments here.
    main()
| |
|