#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
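"""
Training script for a NeRF-style radiance field, driven by a Hydra/OmegaConf config.

Summary of the code below (the usage example is illustrative; the available config
fields come from the YAML files in CONFIG_DIR, and the script filename is assumed):

* builds a RadianceFieldRenderer from the raysampler / implicit_function settings,
* optionally resumes model, optimizer and stats from a checkpoint,
* trains with Adam and an exponentially decaying learning rate,
* periodically evaluates on a random validation view and, if enabled, plots progress
  to a Visdom server,
* periodically stores checkpoints.

Example:
    python train_nerf.py optimizer.max_epochs=2000 visualization.visdom=False
"""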
import collections
import os
import pickle
import warnings

import hydra
import numpy as np
import torch
from nerf.dataset import get_nerf_datasets, trivial_collate
from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs
from nerf.stats import Stats
from omegaconf import DictConfig
from visdom import Visdom

CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")


# The default config name below is an assumption (the YAML files in CONFIG_DIR are not
# part of this file); pass --config-name on the command line to select a different one.
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
        visualization=cfg.visualization.visdom,
    )
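
    # The renderer runs a coarse and a fine rendering pass (standard NeRF hierarchical
    # sampling); n_pts_per_ray and n_pts_per_ray_fine control how many points are
    # sampled along each ray in the respective pass, which is why the metrics below
    # report separate coarse and fine MSE/PSNR values.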

    # Move the model to the relevant device.
    model.to(device)

    # Init stats to None before loading.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if len(cfg.checkpoint_path) > 0:
        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            loaded_data = torch.load(checkpoint_path)
            model.load_state_dict(loaded_data["model"])
            stats = pickle.loads(loaded_data["stats"])
            print(f"   => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        optimizer.last_epoch = start_epoch

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Learning rate scheduler setup.
    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size).
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )
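
    # For illustration (these values are hypothetical, not taken from any config here):
    # with lr_scheduler_gamma=0.1 and lr_scheduler_step_size=100, the multiplier at
    # epoch 200 is 0.1 ** (200 / 100) = 0.01, i.e. the learning rate has decayed to
    # 1% of cfg.optimizer.lr.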

    # The learning rate scheduling is implemented with the LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Precache the projection rays.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)
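
    # When rays are precached, the integer camera_idx (rather than None) is passed to
    # the model in the loops below, which presumably lets the renderer look up the
    # cached ray bundle for that camera instead of recomputing it each iteration.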

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # The validation dataloader is just an endless stream of random samples.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )
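
    # num_samples=cfg.optimizer.max_epochs simply makes the random sampler long enough
    # to supply one validation view per epoch; the validation step below draws a single
    # batch from a fresh iterator each time it runs.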

    # Set the model to the training mode.
    model.train()

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.
        for iteration, batch in enumerate(train_dataloader):
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            # Run the forward pass of the model.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # The loss is a sum of coarse and fine MSEs.
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Update the visualization cache.
            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # Adjust the learning rate.
        lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
            # Sample a validation camera/image.
            val_batch = next(iter(val_dataloader))
            val_image, val_camera, camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # Activate eval mode of the model (lets us do a full rendering pass).
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics.
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot the loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            model.train()

        # Checkpoint.
        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and len(cfg.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)


if __name__ == "__main__":
    main()