Spaces:
Sleeping
Sleeping
| """ | |
| Configurations can be overwritten by adding: key=value | |
| Use debug.wandb=False to disable logging to wandb. | |
| """ | |
| import datetime | |
| from datetime import timedelta | |
| import os | |
| import random | |
| import socket | |
| import time | |
| from glob import glob | |
| import hydra | |
| import ipdb # noqa: F401 | |
| import numpy as np | |
| import omegaconf | |
| import torch | |
| import wandb | |
| from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs | |
| from pytorch3d.renderer import PerspectiveCameras | |
| from diffusionsfm.dataset.co3d_v2 import Co3dDataset, unnormalize_image_for_vis | |
| # from diffusionsfm.dataset.multiloader import get_multiloader, MultiDataset | |
| from diffusionsfm.eval.eval_category import evaluate | |
| from diffusionsfm.model.diffuser import RayDiffuser | |
| from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT | |
| from diffusionsfm.model.scheduler import NoiseScheduler | |
| from diffusionsfm.utils.rays import cameras_to_rays, normalize_cameras_batch, compute_ndc_coordinates | |
| from diffusionsfm.utils.visualization import ( | |
| create_training_visualizations, | |
| view_color_coded_images_from_tensor, | |
| ) | |
# umask(0): files/directories created by this run keep full permissions
# (e.g. 777 dirs), so other users on a shared machine can read checkpoints.
os.umask(000)  # Default to 777 permissions
class Trainer(object):
    """Trainer for ray/depth diffusion SfM models.

    Wires together the hydra config, CO3D (or multi-) datasets, the ray
    diffuser model (optionally with a DPT head), HuggingFace Accelerate for
    distributed training, AMP grad scaling, checkpointing, and wandb logging.
    """

    def __init__(self, cfg):
        """Build datasets, model, optimizer, and logging from a hydra config.

        Args:
            cfg: omegaconf config with `training`, `model`, `dataset`,
                `noise_scheduler`, and `debug` sections.
        """
        # Seed every RNG source up front for reproducibility.
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        self.cfg = cfg
        self.debug = cfg.debug
        # --- Training-loop options ---
        self.resume = cfg.training.resume
        self.pretrain_path = cfg.training.pretrain_path
        self.batch_size = cfg.training.batch_size
        self.max_iterations = cfg.training.max_iterations
        self.mixed_precision = cfg.training.mixed_precision
        self.interval_visualize = cfg.training.interval_visualize
        self.interval_save_checkpoint = cfg.training.interval_save_checkpoint
        self.interval_delete_checkpoint = cfg.training.interval_delete_checkpoint
        self.interval_evaluate = cfg.training.interval_evaluate
        self.delete_all = cfg.training.delete_all_checkpoints_after_training
        self.freeze_encoder = cfg.training.freeze_encoder
        self.translation_scale = cfg.training.translation_scale
        self.regression = cfg.training.regression
        self.prob_unconditional = cfg.training.prob_unconditional
        self.load_extra_cameras = cfg.training.load_extra_cameras
        self.calculate_intrinsics = cfg.training.calculate_intrinsics
        self.distort = cfg.training.distort
        # What geometry is diffused: ray segments (origins+endpoints)
        # and/or per-pixel depths.
        self.diffuse_origins_and_endpoints = cfg.training.diffuse_origins_and_endpoints
        self.diffuse_depths = cfg.training.diffuse_depths
        self.depth_resolution = cfg.training.depth_resolution
        self.dpt_head = cfg.training.dpt_head
        self.full_num_patches_x = cfg.training.full_num_patches_x
        self.full_num_patches_y = cfg.training.full_num_patches_y
        self.dpt_encoder_features = cfg.training.dpt_encoder_features
        self.nearest_neighbor = cfg.training.nearest_neighbor
        self.no_bg_targets = cfg.training.no_bg_targets
        self.unit_normalize_scene = cfg.training.unit_normalize_scene
        self.sd_scale = cfg.training.sd_scale
        self.bfloat = cfg.training.bfloat
        self.first_cam_mediod = cfg.training.first_cam_mediod
        self.normalize_first_camera = cfg.training.normalize_first_camera
        self.gradient_clipping = cfg.training.gradient_clipping
        self.l1_loss = cfg.training.l1_loss
        self.reinit = cfg.training.reinit
        # Mediod-centering of the first camera is only meaningful when
        # normalization is relative to the first camera.
        if self.first_cam_mediod:
            assert self.normalize_first_camera
        # --- Model options ---
        self.pred_x0 = cfg.model.pred_x0
        self.num_patches_x = cfg.model.num_patches_x
        self.num_patches_y = cfg.model.num_patches_y
        self.depth = cfg.model.depth
        self.num_images = cfg.model.num_images
        # Cap wandb visualizations at 2 examples per step.
        self.num_visualize = min(self.batch_size, 2)
        self.random_num_images = cfg.model.random_num_images
        self.feature_extractor = cfg.model.feature_extractor
        self.append_ndc = cfg.model.append_ndc
        self.use_homogeneous = cfg.model.use_homogeneous
        self.freeze_transformer = cfg.model.freeze_transformer
        self.cond_depth_mask = cfg.model.cond_depth_mask
        # --- Dataset options ---
        self.dataset_name = cfg.dataset.name
        self.shape = cfg.dataset.shape
        self.apply_augmentation = cfg.dataset.apply_augmentation
        self.mask_holes = cfg.dataset.mask_holes
        self.image_size = cfg.dataset.image_size
        # Diffusing geometry without masking invalid depth pixels would put
        # garbage in the targets, so require one of the masking mechanisms.
        if not self.regression and (self.diffuse_origins_and_endpoints or self.diffuse_depths):
            assert self.mask_holes or self.cond_depth_mask
        # Regression mode directly predicts x0 (no noise prediction).
        if self.regression:
            assert self.pred_x0
        # Bookkeeping state, possibly overwritten by load_model on resume.
        self.start_time = None
        self.iteration = 0
        self.epoch = 0
        self.wandb_id = None
        self.hostname = socket.gethostname()
        # With the DPT head some parameters receive no gradient each step;
        # DDP must be told so it does not hang waiting for them.
        if self.dpt_head:
            find_unused_parameters = True
        else:
            find_unused_parameters = False
        ddp_scaler = DistributedDataParallelKwargs(
            find_unused_parameters=find_unused_parameters
        )
        # Generous 90-minute process-group timeout for slow dataset startup.
        init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
        self.accelerator = Accelerator(
            even_batches=False,
            device_placement=False,
            kwargs_handlers=[ddp_scaler, init_kwargs],
        )
        self.device = self.accelerator.device
        scheduler = NoiseScheduler(
            type=cfg.noise_scheduler.type,
            max_timesteps=cfg.noise_scheduler.max_timesteps,
            beta_start=cfg.noise_scheduler.beta_start,
            beta_end=cfg.noise_scheduler.beta_end,
        )
        # Instantiate either the DPT-headed diffuser or the plain one.
        if self.dpt_head:
            self.model = RayDiffuserDPT(
                depth=self.depth,
                width=self.num_patches_x,
                P=1,
                max_num_images=self.num_images,
                noise_scheduler=scheduler,
                freeze_encoder=self.freeze_encoder,
                feature_extractor=self.feature_extractor,
                append_ndc=self.append_ndc,
                use_unconditional=self.prob_unconditional > 0,
                diffuse_depths=self.diffuse_depths,
                depth_resolution=self.depth_resolution,
                encoder_features=self.dpt_encoder_features,
                use_homogeneous=self.use_homogeneous,
                freeze_transformer=self.freeze_transformer,
                cond_depth_mask=self.cond_depth_mask,
            ).to(self.device)
        else:
            self.model = RayDiffuser(
                depth=self.depth,
                width=self.num_patches_x,
                P=1,
                max_num_images=self.num_images,
                noise_scheduler=scheduler,
                freeze_encoder=self.freeze_encoder,
                feature_extractor=self.feature_extractor,
                append_ndc=self.append_ndc,
                use_unconditional=self.prob_unconditional > 0,
                diffuse_depths=self.diffuse_depths,
                depth_resolution=self.depth_resolution,
                use_homogeneous=self.use_homogeneous,
                cond_depth_mask=self.cond_depth_mask,
            ).to(self.device)
        # Spatial resolution of the depth/ray targets depends on the head.
        if self.dpt_head:
            depth_size = self.full_num_patches_x
        elif self.depth_resolution > 1:
            depth_size = self.num_patches_x * self.depth_resolution
        else:
            depth_size = self.num_patches_x
        self.depth_size = depth_size
        if self.dataset_name == "multi":
            # NOTE(review): `get_multiloader` is referenced here but its import
            # at the top of the file is commented out — selecting
            # dataset.name == "multi" will raise NameError until it is restored.
            self.dataset, self.train_dataloader, self.test_dataset = get_multiloader(
                num_images=self.num_images,
                apply_augmentation=self.apply_augmentation,
                load_extra_cameras=self.load_extra_cameras,
                distort_image=self.distort,
                center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths),
                load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                depth_size=depth_size,
                mask_holes=self.mask_holes,
                img_size=self.image_size,
                batch_size=self.batch_size,
                num_workers=cfg.training.num_workers,
                dust3r_pairs=True,
            )
        elif self.dataset_name == "co3d":
            self.dataset = Co3dDataset(
                category=self.shape,
                split="train",
                num_images=self.num_images,
                apply_augmentation=self.apply_augmentation,
                load_extra_cameras=self.load_extra_cameras,
                distort_image=self.distort,
                center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths),
                load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                depth_size=depth_size,
                mask_holes=self.mask_holes,
                img_size=self.image_size,
            )
            self.train_dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=cfg.training.num_workers,
                pin_memory=True,
                drop_last=True,
            )
            # Test split: same geometry options but no augmentation.
            self.test_dataset = Co3dDataset(
                category=self.shape,
                split="test",
                num_images=self.num_images,
                apply_augmentation=False,
                load_extra_cameras=self.load_extra_cameras,
                distort_image=self.distort,
                center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths),
                load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                depth_size=depth_size,
                mask_holes=self.mask_holes,
                img_size=self.image_size,
            )
        else:
            raise NotImplementedError(f"Dataset '{self.dataset_name}' is not supported.")
        self.lr = 1e-4  # Fixed base learning rate (not exposed in the config).
        self.output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
        self.checkpoint_dir = os.path.join(self.output_dir, "checkpoints")
        if self.accelerator.is_main_process:
            # Run name = hydra output dir basename + configured run name.
            name = os.path.basename(self.output_dir)
            name += f"_{self.debug.run_name}"
            print("Output dir:", self.output_dir)
            with open(os.path.join(self.output_dir, name), "w"):
                # Create empty tag with name
                pass
            self.name = name
            conf_dict = omegaconf.OmegaConf.to_container(
                cfg, resolve=True, throw_on_missing=True
            )
            conf_dict["output_dir"] = self.output_dir
            conf_dict["hostname"] = self.hostname
        # DPT runs fine-tune the encoder/DiT at a lower learning rate.
        if self.dpt_head:
            self.init_optimizer_with_separate_lrs()
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # Large growth_interval keeps the loss scale stable for long runs;
        # the scaler is a no-op when mixed precision is off.
        self.gradscaler = torch.cuda.amp.GradScaler(growth_interval=100000, enabled=self.mixed_precision)
        self.model, self.optimizer, self.train_dataloader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader
        )
        if self.resume:
            # Resume from the lexicographically last (= latest) checkpoint.
            checkpoint_files = sorted(glob(os.path.join(self.checkpoint_dir, "*.pth")))
            last_checkpoint = checkpoint_files[-1]
            print("Resuming from checkpoint:", last_checkpoint)
            self.load_model(last_checkpoint, load_metadata=True)
        elif self.pretrain_path != "":
            # Warm-start weights only; iteration/optimizer state start fresh.
            print("Loading pretrained model:", self.pretrain_path)
            self.load_model(self.pretrain_path, load_metadata=False)
        if self.accelerator.is_main_process:
            mode = "online" if cfg.debug.wandb else "disabled"
            if self.wandb_id is None:
                self.wandb_id = wandb.util.generate_id()
            self.wandb_run = wandb.init(
                mode=mode,
                name=name,
                project=cfg.debug.project_name,
                config=conf_dict,
                resume=self.resume,
                id=self.wandb_id,
            )
            wandb.define_metric("iteration")
            # Log the noise schedule once, as an image, for sanity checking.
            noise_schedule = self.get_module().noise_scheduler.plot_schedule(
                return_image=True
            )
            wandb.log(
                {"Schedule": wandb.Image(noise_schedule, caption="Noise Schedule")}
            )
| def get_module(self): | |
| if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): | |
| model = self.model.module | |
| else: | |
| model = self.model | |
| return model | |
| def init_optimizer_with_separate_lrs(self): | |
| print("Use different LRs for the DINOv2 encoder and DiT!") | |
| feature_extractor_params = [ | |
| p for n, p in self.model.feature_extractor.named_parameters() | |
| ] | |
| feature_extractor_param_names = [ | |
| "feature_extractor." + n for n, _ in self.model.feature_extractor.named_parameters() | |
| ] | |
| ray_predictor_params = [ | |
| p for n, p in self.model.ray_predictor.named_parameters() | |
| ] | |
| ray_predictor_param_names = [ | |
| "ray_predictor." + n for n, p in self.model.ray_predictor.named_parameters() | |
| ] | |
| other_params = [ | |
| p for n, p in self.model.named_parameters() | |
| if n not in feature_extractor_param_names + ray_predictor_param_names | |
| ] | |
| self.optimizer = torch.optim.Adam([ | |
| {'params': feature_extractor_params, 'lr': self.lr * 0.1}, # Lower LR for feature extractor | |
| {'params': ray_predictor_params, 'lr': self.lr * 0.1}, # Lower LR for DIT (ray_predictor) | |
| {'params': other_params, 'lr': self.lr} # Normal LR for other parts of the model | |
| ]) | |
    def train(self):
        """Main training loop.

        Iterates over the dataloader until `max_iterations`, computing the
        diffusion (or regression) loss on ray/depth targets under AMP,
        stepping the optimizer, and periodically logging, visualizing,
        checkpointing, and evaluating on the main process.
        """
        while self.iteration < self.max_iterations:
            for batch in self.train_dataloader:
                t0 = time.time()
                self.optimizer.zero_grad()
                float_type = torch.bfloat16 if self.bfloat else torch.float16
                with torch.cuda.amp.autocast(
                    enabled=self.mixed_precision, dtype=float_type
                ):
                    # Unpack the batch onto the training device.
                    images = batch["image"].to(self.device)
                    focal_lengths = batch["focal_length"].to(self.device)
                    crop_params = batch["crop_parameters"].to(self.device)
                    principal_points = batch["principal_point"].to(self.device)
                    R = batch["R"].to(self.device)
                    T = batch["T"].to(self.device)
                    if "distortion_coefficients" in batch:
                        distortion_coefficients = batch["distortion_coefficients"]
                    else:
                        # One `None` per batch element keeps downstream zips aligned.
                        distortion_coefficients = [None for _ in range(R.shape[0])]
                    depths = batch["depth"].to(self.device)
                    # NOTE(review): `masks` is only bound when no_bg_targets is
                    # set, but it is also read later when cond_depth_mask or
                    # load_extra_cameras is set — confirm those configs always
                    # come with no_bg_targets, else this NameErrors.
                    if self.no_bg_targets:
                        masks = batch["depth_masks"].to(self.device).bool()
                    # One PyTorch3D camera object per batch element.
                    cameras_og = [
                        PerspectiveCameras(
                            focal_length=focal_lengths[b],
                            principal_point=principal_points[b],
                            R=R[b],
                            T=T[b],
                            device=self.device,
                        )
                        for b in range(self.batch_size)
                    ]
                    # Normalize camera poses (and, when diffusing geometry,
                    # the depths) into the canonical training frame.
                    cameras, _ = normalize_cameras_batch(
                        cameras=cameras_og,
                        scale=self.translation_scale,
                        normalize_first_camera=self.normalize_first_camera,
                        depths=(
                            None
                            if not (self.diffuse_origins_and_endpoints or self.diffuse_depths)
                            else depths
                        ),
                        first_cam_mediod=self.first_cam_mediod,
                        crop_parameters=crop_params,
                        num_patches_x=self.depth_size,
                        num_patches_y=self.depth_size,
                        distortion_coeffs=distortion_coefficients,
                    )
                    # Now that cameras are normalized, fix shapes of camera parameters
                    if self.load_extra_cameras or self.random_num_images:
                        if self.random_num_images:
                            # Sample how many views to train on this step (2..N).
                            num_images = torch.randint(2, self.num_images + 1, (1,))
                        else:
                            num_images = self.num_images
                        # The correct number of images is already loaded.
                        # Only need to modify these camera parameters shapes.
                        focal_lengths = focal_lengths[:, :num_images]
                        crop_params = crop_params[:, :num_images]
                        R = R[:, :num_images]
                        T = T[:, :num_images]
                        images = images[:, :num_images]
                        depths = depths[:, :num_images]
                        masks = masks[:, :num_images]
                        cameras = [
                            PerspectiveCameras(
                                focal_length=cameras[b].focal_length[:num_images],
                                principal_point=cameras[b].principal_point[:num_images],
                                R=cameras[b].R[:num_images],
                                T=cameras[b].T[:num_images],
                                device=self.device,
                            )
                            for b in range(self.batch_size)
                        ]
                    # Regression mode always uses the final (max) timestep.
                    if self.regression:
                        low = self.get_module().noise_scheduler.max_timesteps - 1
                    else:
                        low = 0
                    t = torch.randint(
                        low=low,
                        high=self.get_module().noise_scheduler.max_timesteps,
                        size=(self.batch_size,),
                        device=self.device,
                    )
                    # Classifier-free guidance style dropout of conditioning.
                    if self.prob_unconditional > 0:
                        unconditional_mask = (
                            (torch.rand(self.batch_size) < self.prob_unconditional)
                            .float()
                            .to(self.device)
                        )
                    else:
                        unconditional_mask = None
                    if self.distort:
                        raise NotImplementedError()
                    else:
                        gt_rays = []
                        rays_dirs = []
                        rays = []
                        # Convert each element's cameras (+depths) into the
                        # ray representation that the diffuser is trained on.
                        for i, (camera, crop_param, depth) in enumerate(
                            zip(cameras, crop_params, depths)
                        ):
                            if self.diffuse_origins_and_endpoints:
                                mode = "segment"
                            else:
                                mode = "plucker"
                            r = cameras_to_rays(
                                cameras=camera,
                                num_patches_x=self.full_num_patches_x,
                                num_patches_y=self.full_num_patches_y,
                                crop_parameters=crop_param,
                                depths=depth,
                                mode=mode,
                                depth_resolution=self.depth_resolution,
                                nearest_neighbor=self.nearest_neighbor,
                                distortion_coefficients=distortion_coefficients[i],
                            )
                            rays_dirs.append(r.get_directions())
                            gt_rays.append(r)
                            if self.diffuse_origins_and_endpoints:
                                assert r.mode == "segment"
                            elif self.diffuse_depths:
                                assert r.mode == "plucker"
                            if self.unit_normalize_scene:
                                if self.diffuse_origins_and_endpoints:
                                    assert r.mode == "segment"
                                    # Let's say SD should be 0.5
                                    # Rescale scene so the segment std hits the
                                    # configured target; applied consistently to
                                    # camera translation, rays, and depths.
                                    scale = r.get_segments().std() * self.sd_scale
                                    if scale.isnan().any():
                                        assert False
                                    camera.T /= scale
                                    r.rays /= scale
                                    depths[i] /= scale
                                else:
                                    assert r.mode == "plucker"
                                    scale = r.depths.std() * self.sd_scale
                                    if scale.isnan().any():
                                        assert False
                                    camera.T /= scale
                                    r.depths /= scale
                                    depths[i] /= scale
                            rays.append(
                                r.to_spatial(
                                    include_ndc_coordinates=self.append_ndc,
                                    include_depths=self.diffuse_depths,
                                    use_homogeneous=self.use_homogeneous,
                                )
                            )
                        rays_tensor = torch.stack(rays, dim=0)
                    if self.append_ndc:
                        # NDC xy was packed into the last two channels by
                        # to_spatial; split it off as conditioning.
                        ndc_coordinates = rays_tensor[..., -2:, :, :]
                        rays_tensor = rays_tensor[..., :-2, :, :]
                        if self.dpt_head:
                            # DPT conditioning uses a coarser (1/16) NDC grid.
                            xy_grid = compute_ndc_coordinates(
                                crop_params,
                                num_patches_x=self.depth_size // 16,
                                num_patches_y=self.depth_size // 16,
                                distortion_coeffs=distortion_coefficients,
                            )[..., :2]
                            ndc_coordinates = xy_grid.permute(0, 1, 4, 2, 3).contiguous()
                    else:
                        ndc_coordinates = None
                    if self.cond_depth_mask:
                        condition_mask = masks
                    else:
                        condition_mask = None
                    # Debugging trap: dump the offending batch and break in.
                    # (The file is a pickle despite the ".json" name.)
                    if rays_tensor.isnan().any():
                        import pickle
                        with open("bad.json", "wb") as f:
                            pickle.dump(batch, f)
                        ipdb.set_trace()
                    eps_pred, eps = self.model(
                        images=images,
                        rays=rays_tensor,
                        t=t,
                        ndc_coordinates=ndc_coordinates,
                        unconditional_mask=unconditional_mask,
                        depth_mask=condition_mask,
                    )
                    # Target is x0 (the clean rays) or the injected noise.
                    if self.pred_x0:
                        target = rays_tensor
                    else:
                        target = eps
                    if self.no_bg_targets:
                        # Zero out background pixels in both prediction and
                        # target so they contribute nothing to the loss.
                        C = eps_pred.shape[2]
                        loss_masks = masks.unsqueeze(2).repeat(1, 1, C, 1, 1)
                        eps_pred = loss_masks * eps_pred
                        target = loss_masks * target
                    loss = 0
                    if self.l1_loss:
                        loss_reconstruction = torch.mean(torch.abs(eps_pred - target))
                    else:
                        loss_reconstruction = torch.mean((eps_pred - target) ** 2)
                    loss += loss_reconstruction
                if self.mixed_precision:
                    # Scaled backward; the three norm computations below track
                    # the gradient norm at each stage (scaled -> clipped ->
                    # unscaled) purely for logging.
                    self.gradscaler.scale(loss).backward()
                    scaled_norm = 0
                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            param_norm = p.grad.data.norm(2)
                            scaled_norm += param_norm.item() ** 2
                    scaled_norm = scaled_norm ** 0.5
                    # NOTE(review): max-norm here is 1 but 10 in the non-AMP
                    # branch below — confirm the asymmetry is intentional.
                    if self.gradient_clipping and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.get_module().parameters(), 1
                        )
                    clipped_norm = 0
                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            param_norm = p.grad.data.norm(2)
                            clipped_norm += param_norm.item() ** 2
                    clipped_norm = clipped_norm ** 0.5
                    self.gradscaler.unscale_(self.optimizer)
                    unscaled_norm = 0
                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            param_norm = p.grad.data.norm(2)
                            unscaled_norm += param_norm.item() ** 2
                    unscaled_norm = unscaled_norm ** 0.5
                    self.gradscaler.step(self.optimizer)
                    self.gradscaler.update()
                else:
                    self.accelerator.backward(loss)
                    if self.gradient_clipping and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.get_module().parameters(), 10
                        )
                    self.optimizer.step()
                if self.accelerator.is_main_process:
                    if self.iteration % 10 == 0:
                        self.log_info(
                            loss_reconstruction,
                            t0,
                            self.lr,
                            scaled_norm,
                            unscaled_norm,
                            clipped_norm,
                        )
                    if self.iteration % self.interval_visualize == 0:
                        self.visualize(
                            images=unnormalize_image_for_vis(images.clone()),
                            cameras_gt=cameras,
                            depths=depths,
                            crop_parameters=crop_params,
                            distortion_coefficients=distortion_coefficients,
                            depth_mask=masks,
                        )
                    if self.iteration % self.interval_save_checkpoint == 0 and self.iteration != 0:
                        self.save_model()
                    if self.iteration % self.interval_delete_checkpoint == 0:
                        self.clear_old_checkpoints(self.checkpoint_dir)
                    if (
                        self.iteration % self.interval_evaluate == 0
                        and self.iteration > 0
                    ):
                        self.evaluate_train_acc()
                if self.iteration >= self.max_iterations + 1:
                    # Done: optionally keep only the final checkpoint.
                    if self.delete_all:
                        self.clear_old_checkpoints(
                            self.checkpoint_dir, clear_all_old=True
                        )
                    return
                self.iteration += 1
                # One-shot mid-training re-initialization: rebuild the DPT
                # model with encoder/transformer unfrozen, restore the current
                # weights, and switch to the split-LR optimizer.
                if self.reinit and self.iteration >= 50000:
                    state_dict = self.get_module().state_dict()
                    self.model = RayDiffuserDPT(
                        depth=self.depth,
                        width=self.num_patches_x,
                        P=1,
                        max_num_images=self.num_images,
                        noise_scheduler=self.get_module().noise_scheduler,
                        freeze_encoder=False,
                        feature_extractor=self.feature_extractor,
                        append_ndc=self.append_ndc,
                        use_unconditional=self.prob_unconditional > 0,
                        diffuse_depths=self.diffuse_depths,
                        depth_resolution=self.depth_resolution,
                        encoder_features=self.dpt_encoder_features,
                        use_homogeneous=self.use_homogeneous,
                        freeze_transformer=False,
                        cond_depth_mask=self.cond_depth_mask,
                    ).to(self.device)
                    self.init_optimizer_with_separate_lrs()
                    self.gradscaler = torch.cuda.amp.GradScaler(growth_interval=100000, enabled=self.mixed_precision)
                    self.model, self.optimizer = self.accelerator.prepare(
                        self.model, self.optimizer
                    )
                    msg = self.get_module().load_state_dict(
                        state_dict,
                        strict=True,
                    )
                    print(msg)
                    self.reinit = False
            self.epoch += 1
    def load_model(self, path, load_metadata=True):
        """Load weights (and optionally training state) from a checkpoint.

        Args:
            path: checkpoint file produced by `save_model`.
            load_metadata: when True also restore iteration/epoch/elapsed
                time, wandb id, optimizer, and grad-scaler state (resume);
                when False only weights are loaded (pretraining warm start).
        """
        save_dict = torch.load(path, map_location=self.device)
        # Positional-encoding table is rebuilt by the model, so never load it.
        # NOTE(review): this del assumes the key is always present — a
        # checkpoint without it would raise KeyError; confirm all checkpoint
        # producers include it.
        del save_dict["state_dict"]["ray_predictor.x_pos_enc.image_pos_table"]
        if not self.resume:
            # When warm-starting a DPT model from a checkpoint whose input
            # projection was a linear layer (2-D weight), convert that linear
            # weight into an equivalent 16x16 conv initialization.
            if len(save_dict["state_dict"]["scratch.input_conv.weight"].shape) == 2 and self.dpt_head:
                print("Initialize conv layer weights from the linear layer!")
                C = save_dict["state_dict"]["scratch.input_conv.weight"].shape[1]
                # Spread the linear weight over the 16x16 kernel; /256 keeps
                # the conv output magnitude equal to the linear layer's.
                input_conv_weight = save_dict["state_dict"]["scratch.input_conv.weight"].view(384, C, 1, 1).repeat(1, 1, 16, 16) / 256.
                input_conv_bias = save_dict["state_dict"]["scratch.input_conv.bias"]
                self.get_module().scratch.input_conv.weight.data = input_conv_weight
                self.get_module().scratch.input_conv.bias.data = input_conv_bias
                # Remove so load_state_dict below does not overwrite the
                # hand-initialized conv weights.
                del save_dict["state_dict"]["scratch.input_conv.weight"]
                del save_dict["state_dict"]["scratch.input_conv.bias"]
        missing, unexpected = self.get_module().load_state_dict(
            save_dict["state_dict"],
            strict=False,
        )
        print(f"Missing keys: {missing}")
        print(f"Unexpected keys: {unexpected}")
        if load_metadata:
            self.iteration = save_dict["iteration"]
            self.epoch = save_dict["epoch"]
            # Re-anchor start_time so elapsed time carries across restarts.
            time_elapsed = save_dict["elapsed"]
            self.start_time = time.time() - time_elapsed
            if "wandb_id" in save_dict:
                self.wandb_id = save_dict["wandb_id"]
            self.optimizer.load_state_dict(save_dict["optimizer"])
            self.gradscaler.load_state_dict(save_dict["gradscaler"])
| def save_model(self): | |
| path = os.path.join(self.checkpoint_dir, f"ckpt_{self.iteration:08d}.pth") | |
| os.makedirs(os.path.dirname(path), exist_ok=True) | |
| elapsed = time.time() - self.start_time if self.start_time is not None else 0 | |
| save_dict = { | |
| "epoch": self.epoch, | |
| "elapsed": elapsed, | |
| "gradscaler": self.gradscaler.state_dict(), | |
| "iteration": self.iteration, | |
| "state_dict": self.get_module().state_dict(), | |
| "optimizer": self.optimizer.state_dict(), | |
| "wandb_id": self.wandb_id, | |
| } | |
| torch.save(save_dict, path) | |
| def clear_old_checkpoints(self, checkpoint_dir, clear_all_old=False): | |
| print("Clearing old checkpoints") | |
| checkpoint_files = sorted(glob(os.path.join(checkpoint_dir, "ckpt_*.pth"))) | |
| if clear_all_old: | |
| for checkpoint_file in checkpoint_files[:-1]: | |
| os.remove(checkpoint_file) | |
| else: | |
| for checkpoint_file in checkpoint_files: | |
| checkpoint = os.path.basename(checkpoint_file) | |
| checkpoint_iteration = int("".join(filter(str.isdigit, checkpoint))) | |
| if checkpoint_iteration % self.interval_delete_checkpoint != 0: | |
| os.remove(checkpoint_file) | |
| def log_info( | |
| self, | |
| loss, | |
| t0, | |
| lr, | |
| scaled_norm, | |
| unscaled_norm, | |
| clipped_norm, | |
| ): | |
| if self.start_time is None: | |
| self.start_time = time.time() | |
| time_elapsed = round(time.time() - self.start_time) | |
| time_remaining = round( | |
| (time.time() - self.start_time) | |
| / (self.iteration + 1) | |
| * (self.max_iterations - self.iteration) | |
| ) | |
| disp = [ | |
| f"Iter: {self.iteration}/{self.max_iterations}", | |
| f"Epoch: {self.epoch}", | |
| f"Loss: {loss.item():.4f}", | |
| f"LR: {lr:.7f}", | |
| f"Grad Norm: {scaled_norm:.4f}/{unscaled_norm:.4f}/{clipped_norm:.4f}", | |
| f"Elap: {str(datetime.timedelta(seconds=time_elapsed))}", | |
| f"Rem: {str(datetime.timedelta(seconds=time_remaining))}", | |
| self.hostname, | |
| self.name, | |
| ] | |
| print(", ".join(disp), flush=True) | |
| wandb_log = { | |
| "loss": loss.item(), | |
| "iter_time": time.time() - t0, | |
| "lr": lr, | |
| "iteration": self.iteration, | |
| "hours_remaining": time_remaining / 3600, | |
| "gradient norm": scaled_norm, | |
| "unscaled norm": unscaled_norm, | |
| "clipped norm": clipped_norm, | |
| } | |
| wandb.log(wandb_log) | |
| def visualize( | |
| self, | |
| images, | |
| cameras_gt, | |
| crop_parameters=None, | |
| depths=None, | |
| distortion_coefficients=None, | |
| depth_mask=None, | |
| high_loss=False, | |
| ): | |
| self.get_module().eval() | |
| for camera in cameras_gt: | |
| # AMP may not cast back to float | |
| camera.R = camera.R.float() | |
| camera.T = camera.T.float() | |
| loss_tag = "" if not high_loss else " HIGH LOSS" | |
| for i in range(self.num_visualize): | |
| imgs = view_color_coded_images_from_tensor(images[i].cpu(), depth=False) | |
| im = wandb.Image(imgs, caption=f"iteration {self.iteration} example {i}") | |
| wandb.log({f"Vis images {i}{loss_tag}": im}) | |
| if self.cond_depth_mask: | |
| imgs = view_color_coded_images_from_tensor( | |
| depth_mask[i].cpu(), depth=True | |
| ) | |
| im = wandb.Image( | |
| imgs, caption=f"iteration {self.iteration} example {i}" | |
| ) | |
| wandb.log({f"Vis masks {i}{loss_tag}": im}) | |
| vis_depths, _, _ = create_training_visualizations( | |
| model=self.get_module(), | |
| images=images[: self.num_visualize], | |
| device=self.device, | |
| cameras_gt=cameras_gt, | |
| pred_x0=self.pred_x0, | |
| num_images=images.shape[1], | |
| crop_parameters=crop_parameters[: self.num_visualize], | |
| visualize_pred=self.regression, | |
| return_first=self.regression, | |
| calculate_intrinsics=self.calculate_intrinsics, | |
| mode="segment" if self.diffuse_origins_and_endpoints else "plucker", | |
| depths=depths[: self.num_visualize], | |
| diffuse_depths=self.diffuse_depths, | |
| full_num_patches_x=self.full_num_patches_x, | |
| full_num_patches_y=self.full_num_patches_y, | |
| use_homogeneous=self.use_homogeneous, | |
| distortion_coefficients=distortion_coefficients, | |
| ) | |
| for i, vis_image in enumerate(vis_depths): | |
| im = wandb.Image( | |
| vis_image, caption=f"iteration {self.iteration} example {i}" | |
| ) | |
| for i, vis_image in enumerate(vis_depths): | |
| im = wandb.Image( | |
| vis_image, caption=f"iteration {self.iteration} example {i}" | |
| ) | |
| wandb.log({f"Vis origins and endpoints {i}{loss_tag}": im}) | |
| self.get_module().train() | |
    def evaluate_train_acc(self, num_evaluate=10):
        """Evaluate rotation / camera-center accuracy on train and test splits.

        Runs the project's `evaluate` on up to `num_evaluate` sequences per
        dataset and logs the best (over evaluated timesteps) R@15deg and
        camera-center@0.1 accuracies to wandb. Restores train mode on exit.

        Args:
            num_evaluate: number of sequences to sample per dataset.
        """
        print("Evaluating train accuracy")
        model = self.get_module()
        model.eval()
        # Extra diffusion timesteps at which predictions are also scored.
        additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
        num_images = self.num_images
        for split in ["train", "test"]:
            # Non-CO3D ("multi") wraps several datasets; CO3D is a single one.
            if split == "train":
                if self.dataset_name != "co3d":
                    to_evaluate = self.dataset.datasets
                    names = self.dataset.names
                else:
                    to_evaluate = [self.dataset]
                    names = ["co3d"]
            elif split == "test":
                if self.dataset_name != "co3d":
                    to_evaluate = self.test_dataset.datasets
                    names = self.test_dataset.names
                else:
                    to_evaluate = [self.test_dataset]
                    names = ["co3d"]
            for name, dataset in zip(names, to_evaluate):
                results = evaluate(
                    cfg=self.cfg,
                    model=model,
                    dataset=dataset,
                    num_images=num_images,
                    device=self.device,
                    additional_timesteps=additional_timesteps,
                    num_evaluate=num_evaluate,
                    use_pbar=True,
                    mode="segment" if self.diffuse_origins_and_endpoints else "plucker",
                    metrics=False,
                )
                # Collect per-timestep, per-sequence errors into arrays of
                # shape (num_timesteps, num_sequences, ...).
                R_err = []
                CC_err = []
                for key in results.keys():
                    R_err.append([v["R_error"] for v in results[key]])
                    CC_err.append([v["CC_error"] for v in results[key]])
                R_err = np.array(R_err)
                CC_err = np.array(CC_err)
                # Accuracy per timestep, then report the best timestep.
                R_acc_15 = np.mean(R_err < 15, (0, 2)).max()
                CC_acc = np.mean(CC_err < 0.1, (0, 2)).max()
                wandb.log(
                    {
                        f"R_acc_15_{name}_{split}": R_acc_15,
                        "iteration": self.iteration,
                    }
                )
                wandb.log(
                    {
                        f"CC_acc_0.1_{name}_{split}": CC_acc,
                        "iteration": self.iteration,
                    }
                )
        model.train()
def main(cfg):
    """Entry point: configure torch debugging/precision and run training.

    NOTE(review): `main()` is invoked below with no arguments even though the
    signature requires `cfg`; this only works if a `@hydra.main(...)`
    decorator (apparently lost from this copy of the file) injects the
    config. Confirm against the project's original entry point.
    """
    print(cfg)
    torch.autograd.set_detect_anomaly(cfg.debug.anomaly_detection)
    torch.set_float32_matmul_precision(cfg.training.matmul_precision)
    trainer = Trainer(cfg=cfg)
    trainer.train()


if __name__ == "__main__":
    main()