from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

from einops import rearrange

from diffusers.schedulers import (
    DDPMScheduler,
    DDIMScheduler,
    KarrasVeScheduler,
    DPMSolverMultistepScheduler
)

from michelangelo.utils import instantiate_from_config
from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
from michelangelo.models.asl_diffusion.inference_utils import ddim_sample

SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class ASLDiffuser(pl.LightningModule):
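    """
    Latent diffusion over the aligned shape latent space: a frozen first-stage model
    (AlignedShapeAsLatentPLModule) encodes surfaces into latents, a denoiser network is
    trained under a DDPM noise schedule, and sampling runs through `ddim_sample` with
    optional classifier-free guidance on image / text / surface conditions.
    """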
    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
    model: nn.Module

    def __init__(self, *,
                 first_stage_config,
                 denoiser_cfg,
                 scheduler_cfg,
                 optimizer_cfg,
                 loss_cfg,
                 first_stage_key: str = "surface",
                 cond_stage_key: str = "image",
                 cond_stage_trainable: bool = True,
                 scale_by_std: bool = False,
                 z_scale_factor: float = 1.0,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):
        super().__init__()

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.cond_stage_trainable = cond_stage_trainable
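        # The frozen first-stage model is created lazily in instantiate_first_stage(),
        # the first time forward() or sample() needs it.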
        self.first_stage_config = first_stage_config
        self.first_stage_model = None
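        # Condition encoders are dispatched by key; each "<key>_unconditional_embedding"
        # entry supplies the all-zero embedding used for classifier-free guidance.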
        self.cond_stage_model = {
            "image": self.encode_image,
            "image_unconditional_embedding": self.empty_img_cond,
            "text": self.encode_text,
            "text_unconditional_embedding": self.empty_text_cond,
            "surface": self.encode_surface,
            "surface_unconditional_embedding": self.empty_surface_cond,
        }
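        # Denoiser network operating on the (noisy) shape latents.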
        self.model = instantiate_from_config(
            denoiser_cfg, device=None, dtype=None
        )

        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg
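        # Two diffusers schedulers: `noise` adds noise during training (DDPM),
        # `denoise` drives inference-time sampling (DDIM / Karras / DPM-Solver).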
        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)

        self.loss_cfg = loss_cfg
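        # With scale_by_std, z_scale_factor is a buffer re-estimated from the first
        # training batch (see on_train_batch_start); otherwise it stays fixed.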
        self.scale_by_std = scale_by_std
        if scale_by_std:
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

        self.first_stage_model = self.first_stage_model.to(self.device)
    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")
    @property
    def zero_rank(self):
        if self._trainer:
            zero_rank = self.trainer.local_rank == 0
        else:
            zero_rank = True

        return zero_rank
    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate

        trainable_parameters = list(self.model.parameters())
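        # With no optimizer config, fall back to plain AdamW; otherwise build the
        # optimizer and a per-step LambdaLR schedule from the config.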
        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers
    @torch.no_grad()
    def encode_text(self, text):
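        # text: [b, t, l] token ids (t tokenized prompts per sample); encode each prompt,
        # average over the prompts, then re-normalize the embedding.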
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        return text_embed

    @torch.no_grad()
    def encode_image(self, img):
        return self.first_stage_model.model.encode_image_embed(img)

    @torch.no_grad()
    def encode_surface(self, surface):
        return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)

    @torch.no_grad()
    def empty_text_cond(self, cond):
        return torch.zeros_like(cond, device=cond.device)

    @torch.no_grad()
    def empty_img_cond(self, cond):
        return torch.zeros_like(cond, device=cond.device)

    @torch.no_grad()
    def empty_surface_cond(self, cond):
        return torch.zeros_like(cond, device=cond.device)

    @torch.no_grad()
    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
        z_q = self.first_stage_model.encode(surface, sample_posterior)
        z_q = self.z_scale_factor * z_q

        return z_q

    @torch.no_grad()
    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
        z_q = 1. / self.z_scale_factor * z_q
        latents = self.first_stage_model.decode(z_q, **kwargs)
        return latents
    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        # only on the very first batch of a fresh run (no checkpoint resume)
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
                and batch_idx == 0 and self.ckpt_path is None:
            # set the rescale factor to 1 / std of the first-stage encodings
            print("### USING STD-RESCALING ###")

            z_q = self.encode_first_stage(batch[self.first_stage_key])
            z = z_q.detach()

            del self.z_scale_factor
            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
            print(f"setting self.z_scale_factor to {self.z_scale_factor}")

            print("### USING STD-RESCALING ###")
    def compute_loss(self, model_outputs, split):
        """
        Args:
            model_outputs (dict):
                - x_0: the clean latents
                - noise: the noise added by the forward diffusion
                - pred: the denoiser output (noise or x_0, depending on the
                  scheduler's prediction_type)

            split (str): "train" or "val"; used to prefix the logged metrics.

        Returns:
            total_loss (torch.FloatTensor), loss_dict (dict)
        """

        pred = model_outputs["pred"]

        if self.noise_scheduler.prediction_type == "epsilon":
            target = model_outputs["noise"]
        elif self.noise_scheduler.prediction_type == "sample":
            target = model_outputs["x_0"]
        else:
            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")

        if self.loss_cfg.loss_type == "l1":
            simple = F.l1_loss(pred, target, reduction="mean")
        elif self.loss_cfg.loss_type in ["mse", "l2"]:
            simple = F.mse_loss(pred, target, reduction="mean")
        else:
            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")

        total_loss = simple

        loss_dict = {
            f"{split}/total_loss": total_loss.clone().detach(),
            f"{split}/simple": simple.detach(),
        }

        return total_loss, loss_dict
    def forward(self, batch):
        """
        One diffusion training step: encode the surface into latents, encode the
        condition (with random dropout for classifier-free guidance), add noise at
        random timesteps, and run the denoiser.

        Args:
            batch (dict): must contain `self.first_stage_key` and `self.cond_stage_key`.

        Returns:
            diffusion_outputs (dict): with keys "x_0", "noise" and "pred".
        """

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        latents = self.encode_first_stage(batch[self.first_stage_key])

        conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
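        # Randomly zero out ~10% of the conditions so the model also learns the
        # unconditional branch used for classifier-free guidance.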
        mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
        conditions = conditions * mask.to(conditions)
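        # Sample per-example noise and timesteps, then form the noisy latents z_t.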
        noise = torch.randn_like(latents)
        bs = latents.shape[0]
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
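        # Predict the noise (or x_0, depending on the scheduler's prediction_type)
        # from the noisy latents, timesteps and conditions.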
        noise_pred = self.model(noisy_z, timesteps, conditions)

        diffusion_outputs = {
            "x_0": latents,  # clean latents: the regression target when prediction_type == "sample"
            "noise": noise,
            "pred": noise_pred
        }

        return diffusion_outputs
    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """
        Args:
            batch (dict): the batch sample, containing:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], values in [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], values in [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], values in [-1, 1]
                - text (list of str):

            batch_idx (int):

            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss
    def validation_step(self, batch: Dict[str, torch.FloatTensor],
                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """
        Args:
            batch (dict): the batch sample, containing:
                - surface_pc (torch.FloatTensor): [n_pts, 4]
                - surface_feats (torch.FloatTensor): [n_pts, c]
                - text (list of str):

            batch_idx (int):

            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss
    @torch.no_grad()
    def sample(self,
               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               return_intermediates: bool = False, **kwargs):

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        if steps is None:
            steps = self.scheduler_cfg.num_inference_steps

        if guidance_scale is None:
            guidance_scale = self.scheduler_cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale > 0

        xc = batch[self.cond_stage_key]
        cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)

        if do_classifier_free_guidance:
            """
            Note: There are two kinds of uncond for text.
            1: using "" as uncond text; (in SAL diffusion)
            2: zeros_like(cond) as uncond text; (in MDM)
            """
            un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
            cond = torch.cat([un_cond, cond], dim=0)

        outputs = []
        latents = None
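        # Without intermediates: run the full denoising loop `sample_times` times and
        # decode only the final latents. With intermediates: run once and decode every
        # `steps // sample_times`-th step (plus the last one).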
        if not return_intermediates:
            for _ in range(sample_times):
                sample_loop = ddim_sample(
                    self.denoise_scheduler,
                    self.model,
                    shape=self.first_stage_model.latent_shape,
                    cond=cond,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    device=self.device,
                    eta=eta,
                    disable_prog=not self.zero_rank
                )
                for sample, t in sample_loop:
                    latents = sample
                outputs.append(self.decode_first_stage(latents, **kwargs))
        else:
            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.model,
                shape=self.first_stage_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=not self.zero_rank
            )

            iter_size = steps // sample_times
            i = 0
            for sample, t in sample_loop:
                latents = sample
                if i % iter_size == 0 or i == steps - 1:
                    outputs.append(self.decode_first_stage(latents, **kwargs))
                i += 1

        return outputs
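

# Sketch of the intended usage (illustrative only: the config path and example values
# below are assumptions, not part of this module):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/aligned_shape_diffuser.yaml")  # hypothetical config
#   module: ASLDiffuser = instantiate_from_config(cfg.model)
#   module = module.eval().to("cuda")
#   batch = {"surface": surface_tensor, "image": image_tensor}   # keys match *_stage_key
#   outputs = module.sample(batch, sample_times=1, steps=50, guidance_scale=7.5)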