import logging
import numpy as np
import os
import shutil
import torch
from PIL import Image
from datetime import datetime
from diffusers import DDPMScheduler, DDIMScheduler
from omegaconf import OmegaConf
from safetensors.torch import load_file
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import List, Union

from marigold.marigold_normals_pipeline import (
    MarigoldNormalsPipeline,
    MarigoldNormalsOutput,
)
from src.util.image_util import img_chw2hwc
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dict_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker, compute_cosine_error
from src.util.multi_res_noise import multi_res_noise_like
from src.util.seeding import generate_seed_sequence


class MarigoldNormalsTrainer:
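    """
    Trainer that fine-tunes a Stable-Diffusion-based pipeline to predict
    surface normals: GT normals are encoded with the frozen VAE, the UNet is
    conditioned on the RGB latent by channel-wise concatenation, and only the
    UNet is optimized.
    """
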
    def __init__(
        self,
        cfg: OmegaConf,
        model: MarigoldNormalsPipeline,
        train_dataloader: DataLoader,
        device,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
    ):
        self.cfg: OmegaConf = cfg
        self.model: MarigoldNormalsPipeline = model
        self.device = device
        # Seed of the global seed sequence; set to None to train unseeded
        self.seed: Union[int, None] = self.cfg.trainer.init_seed
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps

        # Adapt the UNet input layer to accept 8 channels
        if 8 != self.model.unet.config["in_channels"]:
            self._replace_unet_conv_in()

        # Encode the empty text prompt once and reuse it for every batch
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Trainability: freeze the VAE and text encoder, train only the UNet
        self.model.vae.requires_grad_(False)
        self.model.text_encoder.requires_grad_(False)
        self.model.unet.requires_grad_(True)

        # Optimizer
        lr = self.cfg.lr
        self.optimizer = Adam(self.model.unet.parameters(), lr=lr)

        # LR scheduler: warmup followed by iteration-wise exponential decay
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # Training noise scheduler: DDPM with zero-terminal-SNR rescaled betas
        self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config(
            self.model.scheduler.config,
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
        )

        logging.info(
            "DDPM training noise scheduler config is updated: "
            f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, "
            f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}"
        )

        self.prediction_type = self.training_noise_scheduler.config.prediction_type
        assert (
            self.prediction_type == self.model.scheduler.config.prediction_type
        ), "Different prediction types"
        self.scheduler_timesteps = (
            self.training_noise_scheduler.config.num_train_timesteps
        )

        # Inference uses a DDIM scheduler derived from the same config
        self.model.scheduler = DDIMScheduler.from_config(
            self.training_noise_scheduler.config,
        )

        # Eval metrics
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]

        self.train_metrics = MetricTracker(*["loss"])
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])

        # Main metric for saving the best checkpoint
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal

        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."

        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_normals_type = self.cfg.gt_normals_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Multi-resolution noise
        self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
        if self.apply_multi_res_noise:
            self.mr_noise_strength = self.cfg.multi_res_noise.strength
            self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
            self.mr_noise_downscale_strategy = (
                self.cfg.multi_res_noise.downscale_strategy
            )

        # Internal state
        self.epoch = 1
        self.n_batch_in_epoch = 0  # batch counter within the current epoch
        self.effective_iter = 0  # iterations counted at the effective batch size
        self.in_evaluation = False
        self.global_seed_sequence: List = []

    def _replace_unet_conv_in(self):
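        """Widen `conv_in` from 4 to 8 input channels by duplicating the
        pretrained weights along the input dimension and halving them, which
        preserves the magnitude of the initial activations."""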
        # Duplicate the pretrained conv_in weights along the input-channel dim
        _weight = self.model.unet.conv_in.weight.clone()
        _bias = self.model.unet.conv_in.bias.clone()
        _weight = _weight.repeat((1, 2, 1, 1))
        # Halve the weights to keep the activation magnitude unchanged
        _weight *= 0.5
        # Build the new input convolution with the same output channels
        _n_convin_out_channel = self.model.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.model.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")
        # Keep the config consistent with the new layer
        self.model.unet.config["in_channels"] = 8
        logging.info("Unet config is updated")
        return

    def train(self, t_end=None):
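        """Run the training loop until `max_iter` effective iterations are
        reached, or until `t_end` (a `datetime`) passes, in which case the
        latest resumable state is saved and training pauses."""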
        logging.info("Start training")

        device = self.device
        self.model.to(device)

        if self.in_evaluation:
            logging.info(
                "Last evaluation was not finished, will do evaluation before continuing training."
            )
            self.validate()

        self.train_metrics.reset()
        accumulated_step = 0

        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip batches that were already consumed before resuming
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()

                # Globally consistent random generator
                if self.seed is not None:
                    local_seed = self._get_next_seed()
                    rand_num_generator = torch.Generator(device=device)
                    rand_num_generator.manual_seed(local_seed)
                else:
                    rand_num_generator = None

                # Get data
                rgb = batch["rgb_norm"].to(device)
                normals_gt_for_latent = batch[self.gt_normals_type].to(device)
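
                # The valid-pixel mask is given at image resolution, but the loss
                # is computed in latent space: max-pooling the *invalid* mask by
                # the VAE downsampling factor (8) marks a latent position invalid
                # as soon as any pixel it covers is invalid, and the result is
                # repeated over the 4 latent channels.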
                if self.gt_mask_type is not None:
                    valid_mask_for_latent = batch[self.gt_mask_type].to(device)
                    invalid_mask = ~valid_mask_for_latent
                    valid_mask_down = ~torch.max_pool2d(
                        invalid_mask.float(), 8, 8
                    ).bool()
                    valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))

                batch_size = rgb.shape[0]

                with torch.no_grad():
                    # Encode image
                    rgb_latent = self.encode_rgb(rgb)
                    # Encode GT normals (3-channel, encoded like an RGB image)
                    gt_target_latent = self.encode_rgb(normals_gt_for_latent)

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    self.scheduler_timesteps,
                    (batch_size,),
                    device=device,
                    generator=rand_num_generator,
                ).long()
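
                # Multi-resolution noise: instead of i.i.d. Gaussian noise, combine
                # noise drawn at multiple spatial scales (see multi_res_noise_like);
                # with annealing enabled, the strength scales linearly with the
                # timestep, approaching zero at the least noisy steps.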
                if self.apply_multi_res_noise:
                    strength = self.mr_noise_strength
                    if self.annealed_mr_noise:
                        # Calculate strength depending on t
                        strength = strength * (timesteps / self.scheduler_timesteps)
                    noise = multi_res_noise_like(
                        gt_target_latent,
                        strength=strength,
                        downscale_strategy=self.mr_noise_downscale_strategy,
                        generator=rand_num_generator,
                        device=device,
                    )
                else:
                    noise = torch.randn(
                        gt_target_latent.shape,
                        device=device,
                        generator=rand_num_generator,
                    )

                # Add noise to the latents (diffusion forward process)
                noisy_latents = self.training_noise_scheduler.add_noise(
                    gt_target_latent, noise, timesteps
                )

                # Text embedding of the empty prompt
                text_embed = self.empty_text_embed.to(device).repeat(
                    (batch_size, 1, 1)
                )

                # Concatenate RGB latent and noisy normals latent
                cat_latents = torch.cat([rgb_latent, noisy_latents], dim=1)
                cat_latents = cat_latents.float()

                # Predict the noise residual
                model_pred = self.model.unet(
                    cat_latents, timesteps, text_embed
                ).sample
                if torch.isnan(model_pred).any():
                    logging.warning("model_pred contains NaN.")
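
                # The regression target depends on the scheduler parameterization:
                # "sample" predicts the clean latent x0, "epsilon" the added noise,
                # and "v_prediction" the velocity
                # v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0.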
| if "sample" == self.prediction_type: |
| target = gt_target_latent |
| elif "epsilon" == self.prediction_type: |
| target = noise |
| elif "v_prediction" == self.prediction_type: |
| target = self.training_noise_scheduler.get_velocity( |
| gt_target_latent, noise, timesteps |
| ) |
| else: |
| raise ValueError(f"Unknown prediction type {self.prediction_type}") |
|
|
| |
| if self.gt_mask_type is not None: |
| latent_loss = self.loss( |
| model_pred[valid_mask_down].float(), |
| target[valid_mask_down].float(), |
| ) |
| else: |
| latent_loss = self.loss(model_pred.float(), target.float()) |
|
|
| loss = latent_loss.mean() |
|
|
| self.train_metrics.update("loss", loss.item()) |
|
|
| loss = loss / self.gradient_accumulation_steps |
| loss.backward() |
| accumulated_step += 1 |
|
|
| self.n_batch_in_epoch += 1 |
| |
|
|

                # Perform an optimization step once enough gradients are accumulated
                if accumulated_step >= self.gradient_accumulation_steps:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    accumulated_step = 0

                    self.effective_iter += 1

                    # Log to tensorboard
                    accumulated_loss = self.train_metrics.result()["loss"]
                    tb_logger.log_dict(
                        {
                            f"train/{k}": v
                            for k, v in self.train_metrics.result().items()
                        },
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "lr",
                        self.lr_scheduler.get_last_lr()[0],
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "n_batch_in_epoch",
                        self.n_batch_in_epoch,
                        global_step=self.effective_iter,
                    )
                    logging.info(
                        f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}"
                    )
                    self.train_metrics.reset()

                    # Per-step callback: backup, validation, latest save, visualization
                    self._train_step_callback()

                    # End training when max_iter is reached
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        self.save_checkpoint(
                            ckpt_name=self._get_backup_ckpt_name(),
                            save_train_state=False,
                        )
                        logging.info("Training ended.")
                        return
                    # Time is up: save the resumable state and pause
                    elif t_end is not None and datetime.now() >= t_end:
                        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                        logging.info("Time is up, training paused.")
                        return

                    torch.cuda.empty_cache()
                    # Effective batch end

            # Epoch end
            self.n_batch_in_epoch = 0

    def encode_rgb(self, image_in):
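        """Encode a batch of 3-channel images into 4-channel VAE latents; both
        the RGB input and the GT normals pass through this same frozen encoder."""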
        assert len(image_in.shape) == 4 and image_in.shape[1] == 3
        latent = self.model.encode_rgb(image_in)
        return latent

    def _train_step_callback(self):
        """Executed after every effective iteration."""
        # Save a backup checkpoint (larger interval, without training states)
        if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
            self.save_checkpoint(
                ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
            )

        _is_latest_saved = False
        # Validation
        if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
            self.in_evaluation = True  # re-run evaluation if interrupted mid-way
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)
            _is_latest_saved = True
            self.validate()
            self.in_evaluation = False
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Save the latest checkpoint (rolling save, with training states)
        if (
            self.save_period > 0
            and 0 == self.effective_iter % self.save_period
            and not _is_latest_saved
        ):
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Visualization
        if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
            self.visualize()

    def validate(self):
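        """Evaluate on all validation dataloaders; the first one drives the
        main metric and best-checkpoint saving."""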
        for i, val_loader in enumerate(self.val_loaders):
            val_dataset_name = val_loader.dataset.disp_name
            val_metric_dict = self.validate_single_dataset(
                data_loader=val_loader, metric_tracker=self.val_metrics
            )
            logging.info(
                f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dict}"
            )
            tb_logger.log_dict(
                {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()},
                global_step=self.effective_iter,
            )
            # Save evaluation results to a text file
            eval_text = eval_dict_to_text(
                val_metrics=val_metric_dict,
                dataset_name=val_dataset_name,
                sample_list_path=val_loader.dataset.filename_ls_path,
            )
            _save_to = os.path.join(
                self.out_dir_eval,
                f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
            )
            with open(_save_to, "w+") as f:
                f.write(eval_text)

            # Update the main eval metric and save the best checkpoint
            if 0 == i:
                main_eval_metric = val_metric_dict[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                    or "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
                    self.best_metric = main_eval_metric
                    logging.info(
                        f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                    )
                    # Save the best checkpoint (without training states)
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                    )

    def visualize(self):
        for val_loader in self.vis_loaders:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

    @torch.no_grad()
    def validate_single_dataset(
        self,
        data_loader: DataLoader,
        metric_tracker: MetricTracker,
        save_to_dir: str = None,
    ):
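        """Run the full inference pipeline over one dataset, accumulate metrics
        computed from the angular (cosine) error between predicted and GT
        normals, and optionally save the predictions as PNG visualizations."""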
        self.model.to(self.device)
        metric_tracker.reset()

        # Generate a seed sequence for consistent evaluation
        val_init_seed = self.cfg.validation.init_seed
        val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

        for i, batch in enumerate(
            tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
            start=1,
        ):
            assert 1 == data_loader.batch_size
            # Read input image
            rgb_int = batch["rgb_int"]
            # GT normals
            normals_gt = batch["normals"].to(self.device)

            # Random number generator
            seed = val_seed_ls.pop()
            if seed is None:
                generator = None
            else:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(seed)

            # Predict normals
            pipe_out: MarigoldNormalsOutput = self.model(
                rgb_int,
                denoising_steps=self.cfg.validation.denoising_steps,
                ensemble_size=self.cfg.validation.ensemble_size,
                processing_res=self.cfg.validation.processing_res,
                match_input_res=self.cfg.validation.match_input_res,
                generator=generator,
                batch_size=1,
                show_progress_bar=False,
                resample_method=self.cfg.validation.resample_method,
            )

            normals_pred = pipe_out.normals_np

            # Evaluate the angular (cosine) error against the GT
            normals_pred_ts = (
                torch.from_numpy(normals_pred).unsqueeze(0).to(self.device)
            )
            cosine_error = compute_cosine_error(
                normals_pred_ts, normals_gt, masked=True
            )
            sample_metric = []

            for met_func in self.metric_funcs:
                _metric_name = met_func.__name__
                _metric = met_func(cosine_error).item()
                sample_metric.append(str(_metric))
                metric_tracker.update(_metric_name, _metric)

            # Save the predicted normals as an 8-bit PNG, mapping [-1, 1] to [0, 255]
            if save_to_dir is not None:
                img_name = batch["rgb_relative_path"][0].replace("/", "_")
                png_save_path = os.path.join(save_to_dir, img_name)
                normals_to_save = img_chw2hwc((normals_pred + 1) * 127.5).astype(
                    np.uint8
                )
                Image.fromarray(normals_to_save).save(png_save_path)

        return metric_tracker.result()

    def _get_next_seed(self):
        # Lazily (re-)generate the global seed sequence; pop one seed per batch
        if 0 == len(self.global_seed_sequence):
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    def save_checkpoint(self, ckpt_name, save_train_state):
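        """Save the UNet and scheduler (plus optimizer, LR scheduler, and loop
        counters when `save_train_state` is set) under `out_dir_ckpt/ckpt_name`.
        An existing checkpoint of the same name is renamed first and deleted
        only after the new one is fully written, so an interrupted save cannot
        destroy the last good checkpoint."""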
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")
        # Back up the previous checkpoint of the same name
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=True)
        logging.info(f"UNet is saved to: {unet_path}")

        # Save the current scheduler
        scheduler_path = os.path.join(ckpt_dir, "scheduler")
        self.model.scheduler.save_pretrained(scheduler_path)
        logging.info(f"Scheduler is saved to: {scheduler_path}")

        if save_train_state:
            state = {
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # Create an empty iteration-indicator file
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass
            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove the old checkpoint backup
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet weights; the checkpoint is saved with safe_serialization=True
        # above, so read the safetensors file
        _model_path = os.path.join(
            ckpt_path, "unet", "diffusion_pytorch_model.safetensors"
        )
        self.model.unet.load_state_dict(load_file(_model_path))
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]

            self.best_metric = checkpoint["best_metric"]

            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")

            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return

    def _get_backup_ckpt_name(self):
        return f"iter_{self.effective_iter:06d}"