# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import logging
import numpy as np
import os
import shutil
import torch
from PIL import Image
from datetime import datetime
from diffusers import DDPMScheduler, DDIMScheduler
from omegaconf import OmegaConf
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import List, Union

from marigold.marigold_normals_pipeline import (
MarigoldNormalsPipeline,
MarigoldNormalsOutput,
)
from src.util.image_util import img_chw2hwc
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dict_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker, compute_cosine_error
from src.util.multi_res_noise import multi_res_noise_like
from src.util.seeding import generate_seed_sequence


class MarigoldNormalsTrainer:
def __init__(
self,
cfg: OmegaConf,
model: MarigoldNormalsPipeline,
train_dataloader: DataLoader,
device,
out_dir_ckpt,
out_dir_eval,
out_dir_vis,
accumulation_steps: int,
val_dataloaders: List[DataLoader] = None,
vis_dataloaders: List[DataLoader] = None,
):
self.cfg: OmegaConf = cfg
self.model: MarigoldNormalsPipeline = model
self.device = device
self.seed: Union[int, None] = (
self.cfg.trainer.init_seed
) # used to generate seed sequence, set to `None` to train w/o seeding
self.out_dir_ckpt = out_dir_ckpt
self.out_dir_eval = out_dir_eval
self.out_dir_vis = out_dir_vis
self.train_loader: DataLoader = train_dataloader
self.val_loaders: List[DataLoader] = val_dataloaders
self.vis_loaders: List[DataLoader] = vis_dataloaders
self.accumulation_steps: int = accumulation_steps
# Adapt input layers
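        # The pretrained Stable Diffusion UNet takes 4 latent channels; Marigold
        # conditions on the image by concatenating the RGB latent with the noisy
        # normals latent, so conv_in must accept 4 + 4 = 8 channels.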
if 8 != self.model.unet.config["in_channels"]:
self._replace_unet_conv_in()
# Encode empty text prompt
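        # The embedding of the empty prompt is computed once and cached, so the
        # frozen text encoder does not have to run in every training step.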
self.model.encode_empty_text()
self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
self.model.unet.enable_xformers_memory_efficient_attention()
# Trainability
self.model.vae.requires_grad_(False)
self.model.text_encoder.requires_grad_(False)
self.model.unet.requires_grad_(True)
        # Optimizer (must be created after the input layer is adapted)
lr = self.cfg.lr
self.optimizer = Adam(self.model.unet.parameters(), lr=lr)
# LR scheduler
lr_func = IterExponential(
total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
)
self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
# Loss
self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
# Training noise scheduler
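        # rescale_betas_zero_snr and "trailing" timestep spacing follow the
        # zero-terminal-SNR fixes proposed for Stable Diffusion fine-tuning,
        # so the final timestep really carries pure noise.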
self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config(
self.model.scheduler.config,
rescale_betas_zero_snr=True,
timestep_spacing="trailing",
)
logging.info(
"DDPM training noise scheduler config is updated: "
f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, "
f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}"
)
self.prediction_type = self.training_noise_scheduler.config.prediction_type
assert (
self.prediction_type == self.model.scheduler.config.prediction_type
), "Different prediction types"
self.scheduler_timesteps = (
self.training_noise_scheduler.config.num_train_timesteps
)
# Inference DDIM scheduler (used for validation)
self.model.scheduler = DDIMScheduler.from_config(
self.training_noise_scheduler.config,
)
# Eval metrics
self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
self.train_metrics = MetricTracker(*["loss"])
self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
# main metric for best checkpoint saving
self.main_val_metric = cfg.validation.main_val_metric
self.main_val_metric_goal = cfg.validation.main_val_metric_goal
assert (
self.main_val_metric in cfg.eval.eval_metrics
), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
# Settings
self.max_epoch = self.cfg.max_epoch
self.max_iter = self.cfg.max_iter
self.gradient_accumulation_steps = accumulation_steps
self.gt_normals_type = self.cfg.gt_normals_type
self.gt_mask_type = self.cfg.gt_mask_type
self.save_period = self.cfg.trainer.save_period
self.backup_period = self.cfg.trainer.backup_period
self.val_period = self.cfg.trainer.validation_period
self.vis_period = self.cfg.trainer.visualization_period
# Multi-resolution noise
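        # Multi-resolution (pyramid) noise: noise sampled at several spatial
        # scales is upsampled and blended, which empirically helps convergence
        # when fine-tuning for dense prediction.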
self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
if self.apply_multi_res_noise:
self.mr_noise_strength = self.cfg.multi_res_noise.strength
self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
self.mr_noise_downscale_strategy = (
self.cfg.multi_res_noise.downscale_strategy
)
# Internal variables
self.epoch = 1
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resuming training
self.effective_iter = 0 # how many times optimizer.step() is called
self.in_evaluation = False
        self.global_seed_sequence: List = []  # reproducible global seed sequence used to seed random generators, so that resumed runs stay consistent

    def _replace_unet_conv_in(self):
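        """Duplicate the pretrained 4-channel conv_in weights into a new
        8-channel layer and halve them, so activation statistics match the
        original when both halves receive the same input."""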
# replace the first layer to accept 8 in_channels
_weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3]
_bias = self.model.unet.conv_in.bias.clone() # [320]
        _weight = _weight.repeat((1, 2, 1, 1))  # duplicate the 4-channel weights to 8 channels
        # halve the weights so the doubled input channels keep the pretrained activation magnitude
        _weight *= 0.5
# new conv_in channel
_n_convin_out_channel = self.model.unet.conv_in.out_channels
_new_conv_in = Conv2d(
8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
)
_new_conv_in.weight = Parameter(_weight)
_new_conv_in.bias = Parameter(_bias)
self.model.unet.conv_in = _new_conv_in
logging.info("Unet conv_in layer is replaced")
# replace config
self.model.unet.config["in_channels"] = 8
logging.info("Unet config is updated")
return

    def train(self, t_end=None):
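        """Run the training loop until `max_iter`, `max_epoch`, or the optional
        wall-clock deadline `t_end` is reached, checkpointing along the way."""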
logging.info("Start training")
device = self.device
self.model.to(device)
if self.in_evaluation:
            logging.info(
                "Last evaluation was not finished; running evaluation before continuing training."
            )
self.validate()
self.train_metrics.reset()
accumulated_step = 0
for epoch in range(self.epoch, self.max_epoch + 1):
self.epoch = epoch
logging.debug(f"epoch: {self.epoch}")
            # Skip previously seen batches when resuming
for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
self.model.unet.train()
# globally consistent random generators
if self.seed is not None:
local_seed = self._get_next_seed()
rand_num_generator = torch.Generator(device=device)
rand_num_generator.manual_seed(local_seed)
else:
rand_num_generator = None
# >>> With gradient accumulation >>>
# Get data
rgb = batch["rgb_norm"].to(device)
normals_gt_for_latent = batch[self.gt_normals_type].to(device)
if self.gt_mask_type is not None:
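                    # The mask lives in image space while the loss is computed in
                    # latent space (8x smaller). Max-pooling the *invalid* mask with
                    # stride 8 marks a latent cell invalid if any of its 8x8 pixels
                    # is invalid; the result is broadcast over the 4 latent channels.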
valid_mask_for_latent = batch[self.gt_mask_type].to(device)
invalid_mask = ~valid_mask_for_latent
valid_mask_down = ~torch.max_pool2d(
invalid_mask.float(), 8, 8
).bool()
valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
batch_size = rgb.shape[0]
with torch.no_grad():
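                    # VAE encoding needs no gradients: the VAE is frozen and only
                    # the UNet is trained.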
# Encode image
rgb_latent = self.encode_rgb(rgb) # [B, 4, h, w]
# Encode GT normals
gt_target_latent = self.encode_rgb(
normals_gt_for_latent
) # [B, 4, h, w]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
self.scheduler_timesteps,
(batch_size,),
device=device,
generator=rand_num_generator,
).long() # [B]
# Sample noise
if self.apply_multi_res_noise:
strength = self.mr_noise_strength
if self.annealed_mr_noise:
# calculate strength depending on t
strength = strength * (timesteps / self.scheduler_timesteps)
noise = multi_res_noise_like(
gt_target_latent,
strength=strength,
downscale_strategy=self.mr_noise_downscale_strategy,
generator=rand_num_generator,
device=device,
)
else:
noise = torch.randn(
gt_target_latent.shape,
device=device,
generator=rand_num_generator,
) # [B, 4, h, w]
# Add noise to the latents (diffusion forward process)
noisy_latents = self.training_noise_scheduler.add_noise(
gt_target_latent, noise, timesteps
) # [B, 4, h, w]
# Text embedding
text_embed = self.empty_text_embed.to(device).repeat(
(batch_size, 1, 1)
) # [B, 77, 1024]
# Concat rgb and target latents
cat_latents = torch.cat(
[rgb_latent, noisy_latents], dim=1
) # [B, 8, h, w]
cat_latents = cat_latents.float()
# Predict the noise residual
model_pred = self.model.unet(
cat_latents, timesteps, text_embed
).sample # [B, 4, h, w]
if torch.isnan(model_pred).any():
logging.warning("model_pred contains NaN.")
# Get the target for loss depending on the prediction type
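                # sample: the model predicts x_0 (the clean latent) directly;
                # epsilon: it predicts the added noise;
                # v_prediction: it predicts
                #   v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0,
                # which is what `get_velocity` computes.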
if "sample" == self.prediction_type:
target = gt_target_latent
elif "epsilon" == self.prediction_type:
target = noise
elif "v_prediction" == self.prediction_type:
target = self.training_noise_scheduler.get_velocity(
gt_target_latent, noise, timesteps
)
else:
raise ValueError(f"Unknown prediction type {self.prediction_type}")
# Masked latent loss
if self.gt_mask_type is not None:
latent_loss = self.loss(
model_pred[valid_mask_down].float(),
target[valid_mask_down].float(),
)
else:
latent_loss = self.loss(model_pred.float(), target.float())
loss = latent_loss.mean()
self.train_metrics.update("loss", loss.item())
loss = loss / self.gradient_accumulation_steps
loss.backward()
accumulated_step += 1
self.n_batch_in_epoch += 1
# Practical batch end
# Perform optimization step
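                # Gradients have been accumulated over `gradient_accumulation_steps`
                # batches; the effective batch size is batch_size * accumulation_steps.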
if accumulated_step >= self.gradient_accumulation_steps:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
accumulated_step = 0
self.effective_iter += 1
# Log to tensorboard
accumulated_loss = self.train_metrics.result()["loss"]
tb_logger.log_dict(
{
f"train/{k}": v
for k, v in self.train_metrics.result().items()
},
global_step=self.effective_iter,
)
tb_logger.writer.add_scalar(
"lr",
self.lr_scheduler.get_last_lr()[0],
global_step=self.effective_iter,
)
tb_logger.writer.add_scalar(
"n_batch_in_epoch",
self.n_batch_in_epoch,
global_step=self.effective_iter,
)
logging.info(
f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}"
)
self.train_metrics.reset()
# Per-step callback
self._train_step_callback()
# End of training
if self.max_iter > 0 and self.effective_iter >= self.max_iter:
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(),
save_train_state=False,
)
logging.info("Training ended.")
return
# Time's up
elif t_end is not None and datetime.now() >= t_end:
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
logging.info("Time is up, training paused.")
return
torch.cuda.empty_cache()
# <<< Effective batch end <<<
# Epoch end
self.n_batch_in_epoch = 0

    def encode_rgb(self, image_in):
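        """Encode a [B, 3, H, W] image into a [B, 4, H/8, W/8] latent with the
        frozen VAE of the pipeline."""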
assert len(image_in.shape) == 4 and image_in.shape[1] == 3
latent = self.model.encode_rgb(image_in)
return latent

    def _train_step_callback(self):
"""Executed after every iteration"""
# Save backup (with a larger interval, without training states)
if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
)
_is_latest_saved = False
# Validation
if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
            self.in_evaluation = True  # flag so that a resumed run re-runs evaluation if validation did not finish
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
_is_latest_saved = True
self.validate()
self.in_evaluation = False
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
# Save training checkpoint (can be resumed)
if (
self.save_period > 0
and 0 == self.effective_iter % self.save_period
and not _is_latest_saved
):
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
# Visualization
if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
self.visualize()

    def validate(self):
for i, val_loader in enumerate(self.val_loaders):
val_dataset_name = val_loader.dataset.disp_name
val_metric_dict = self.validate_single_dataset(
data_loader=val_loader, metric_tracker=self.val_metrics
)
logging.info(
f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dict}"
)
tb_logger.log_dict(
{f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()},
global_step=self.effective_iter,
)
# save to file
eval_text = eval_dict_to_text(
val_metrics=val_metric_dict,
dataset_name=val_dataset_name,
sample_list_path=val_loader.dataset.filename_ls_path,
)
_save_to = os.path.join(
self.out_dir_eval,
f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
)
with open(_save_to, "w+") as f:
f.write(eval_text)
# Update main eval metric
if 0 == i:
main_eval_metric = val_metric_dict[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                ) or (
                    "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
self.best_metric = main_eval_metric
logging.info(
f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
)
# Save a checkpoint
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
)

    def visualize(self):
for val_loader in self.vis_loaders:
vis_dataset_name = val_loader.dataset.disp_name
vis_out_dir = os.path.join(
self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
)
os.makedirs(vis_out_dir, exist_ok=True)
_ = self.validate_single_dataset(
data_loader=val_loader,
metric_tracker=self.val_metrics,
save_to_dir=vis_out_dir,
)

    @torch.no_grad()
def validate_single_dataset(
self,
data_loader: DataLoader,
metric_tracker: MetricTracker,
save_to_dir: str = None,
):
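        """Evaluate on a single dataset with a fixed seed sequence; if
        `save_to_dir` is given, also save the predicted normal maps as PNGs."""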
self.model.to(self.device)
metric_tracker.reset()
# Generate seed sequence for consistent evaluation
val_init_seed = self.cfg.validation.init_seed
val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
for i, batch in enumerate(
tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
start=1,
):
assert 1 == data_loader.batch_size
# Read input image
rgb_int = batch["rgb_int"] # [B, 3, H, W]
# GT normals
normals_gt = batch["normals"].to(self.device) # [B, 3, H, W]
# Random number generator
seed = val_seed_ls.pop()
if seed is None:
generator = None
else:
generator = torch.Generator(device=self.device)
generator.manual_seed(seed)
# Predict normals
pipe_out: MarigoldNormalsOutput = self.model(
rgb_int,
denoising_steps=self.cfg.validation.denoising_steps,
ensemble_size=self.cfg.validation.ensemble_size,
processing_res=self.cfg.validation.processing_res,
match_input_res=self.cfg.validation.match_input_res,
generator=generator,
batch_size=1, # use batch size 1 to increase reproducibility
show_progress_bar=False,
resample_method=self.cfg.validation.resample_method,
)
normals_pred = pipe_out.normals_np # [3, H, W]
normals_pred_ts = (
torch.from_numpy(normals_pred).unsqueeze(0).to(self.device)
)
cosine_error = compute_cosine_error(
normals_pred_ts, normals_gt, masked=True
)
sample_metric = []
for met_func in self.metric_funcs:
_metric_name = met_func.__name__
_metric = met_func(cosine_error).item()
                sample_metric.append(str(_metric))
metric_tracker.update(_metric_name, _metric)
# Save predicted normals as images
if save_to_dir is not None:
img_name = batch["rgb_relative_path"][0].replace("/", "_")
png_save_path = os.path.join(save_to_dir, img_name)
                normals_to_save = img_chw2hwc((normals_pred + 1) * 127.5).astype(
                    np.uint8
                )
Image.fromarray(normals_to_save).save(png_save_path)
return metric_tracker.result()

    def _get_next_seed(self):
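        """Pop the next seed from a reproducible global sequence so that resumed
        runs replay identical noise and timestep draws."""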
if 0 == len(self.global_seed_sequence):
self.global_seed_sequence = generate_seed_sequence(
initial_seed=self.seed,
length=self.max_iter * self.gradient_accumulation_steps,
)
logging.info(
f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
)
return self.global_seed_sequence.pop()

    def save_checkpoint(self, ckpt_name, save_train_state):
ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
logging.info(f"Saving checkpoint to: {ckpt_dir}")
# Backup previous checkpoint
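        # Rename the existing checkpoint first and delete it only after the new
        # one is fully written, so a crash mid-save never destroys the last
        # usable checkpoint.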
temp_ckpt_dir = None
if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
temp_ckpt_dir = os.path.join(
os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
)
if os.path.exists(temp_ckpt_dir):
shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
os.rename(ckpt_dir, temp_ckpt_dir)
logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
# Save UNet
unet_path = os.path.join(ckpt_dir, "unet")
self.model.unet.save_pretrained(unet_path, safe_serialization=True)
logging.info(f"UNet is saved to: {unet_path}")
# Save scheduler
        scheduler_path = os.path.join(ckpt_dir, "scheduler")
        self.model.scheduler.save_pretrained(scheduler_path)
        logging.info(f"Scheduler is saved to: {scheduler_path}")
if save_train_state:
state = {
"optimizer": self.optimizer.state_dict(),
"lr_scheduler": self.lr_scheduler.state_dict(),
"config": self.cfg,
"effective_iter": self.effective_iter,
"epoch": self.epoch,
"n_batch_in_epoch": self.n_batch_in_epoch,
"best_metric": self.best_metric,
"in_evaluation": self.in_evaluation,
"global_seed_sequence": self.global_seed_sequence,
}
train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
torch.save(state, train_state_path)
            # iteration indicator: an empty marker file named after the current iteration
            open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w").close()
logging.info(f"Trainer state is saved to: {train_state_path}")
# Remove temp ckpt
if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
logging.info(f"Loading checkpoint from: {ckpt_path}")
# Load UNet
_model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
self.model.unet.load_state_dict(
torch.load(_model_path, map_location=self.device)
)
self.model.unet.to(self.device)
logging.info(f"UNet parameters are loaded from {_model_path}")
# Load training states
if load_trainer_state:
checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
self.effective_iter = checkpoint["effective_iter"]
self.epoch = checkpoint["epoch"]
self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
self.in_evaluation = checkpoint["in_evaluation"]
self.global_seed_sequence = checkpoint["global_seed_sequence"]
self.best_metric = checkpoint["best_metric"]
self.optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(f"optimizer state is loaded from {ckpt_path}")
if resume_lr_scheduler:
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
logging.info(f"LR scheduler state is loaded from {ckpt_path}")
logging.info(
f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
)
return

    def _get_backup_ckpt_name(self):
return f"iter_{self.effective_iter:06d}"