| import os |
| import random |
| from dataclasses import dataclass, field |
|
|
| import torch |
| import torch.nn.functional as F |
| from diffusers import DDPMScheduler, UNet2DConditionModel |
| from diffusers.models import AutoencoderKL |
| from diffusers.training_utils import compute_snr |
| from einops import rearrange |
| from omegaconf import OmegaConf |
| from PIL import Image |
|
|
| from ..pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline |
| from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler |
| from ..utils.core import find |
| from ..utils.typing import * |
| from .base import BaseSystem |
| from .utils import encode_prompt, vae_encode |
|
|
|
|
def compute_embeddings(
    prompt_batch,
    empty_prompt_indices,
    text_encoders,
    tokenizers,
    is_train=True,
    **kwargs,
):
    """Encode a batch of prompts into SDXL conditioning embeddings.

    Prompts whose index is flagged in ``empty_prompt_indices`` are replaced
    with the empty string (classifier-free-guidance style prompt dropout)
    before encoding.

    Args:
        prompt_batch: Sequence of prompt strings. Not mutated; a copy is made
            before dropout is applied.
        empty_prompt_indices: 1-D boolean tensor; ``True`` entries mark
            prompts to blank out. Assumed to have one entry per prompt —
            TODO confirm against callers.
        text_encoders: SDXL text encoders, paired with ``tokenizers``.
        tokenizers: SDXL tokenizers, paired with ``text_encoders``.
        is_train: Forwarded to ``encode_prompt``.
        **kwargs: Must contain ``original_size``, ``target_size`` and
            ``crops_coords_top_left`` (tuples) for SDXL micro-conditioning.

    Returns:
        Dict with ``prompt_embeds``, pooled ``text_embeds`` and ``time_ids``,
        all on the prompt-embedding device/dtype.
    """
    original_size = kwargs["original_size"]
    target_size = kwargs["target_size"]
    crops_coords_top_left = kwargs["crops_coords_top_left"]

    # Copy first so prompt dropout does not leak back into the caller's
    # list (the original implementation mutated `prompt_batch` in place).
    prompt_batch = list(prompt_batch)
    for i in range(empty_prompt_indices.shape[0]):
        if empty_prompt_indices[i]:
            prompt_batch[i] = ""

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, 0, is_train
    )
    add_text_embeds = pooled_prompt_embeds.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    # SDXL micro-conditioning "time ids": one
    # (original_size, crop_top_left, target_size) row per prompt.
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = add_time_ids.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
|
|
|
class IG2MVSDXLSystem(BaseSystem):
    """Training/evaluation system for image-guided multi-view (IG2MV) SDXL.

    Wraps an ``IG2MVSDXLPipeline`` and trains the UNet (plus an optional
    conditioning encoder / custom adapter) with standard diffusion noise
    prediction. Supports Min-SNR loss weighting, noise offset / input
    perturbation, SNR-shifted noise schedules, and classifier-free-guidance
    style prompt/image/conditioning dropout.
    """

    @dataclass
    class Config(BaseSystem.Config):
        # --- pretrained weights ---
        pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-xl-base-1.0"
        pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix"
        pretrained_adapter_name_or_path: Optional[str] = None
        pretrained_unet_name_or_path: Optional[str] = None
        init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict)

        # Run frozen encoders in fp16 to save memory.
        use_fp16_vae: bool = True
        use_fp16_clip: bool = True

        # --- trainable parameter selection ---
        # Substrings matched against UNet module names; empty -> whole UNet.
        trainable_modules: List[str] = field(default_factory=list)
        train_cond_encoder: bool = True
        # CFG dropout probabilities (per sample, not per view).
        prompt_drop_prob: float = 0.0
        image_drop_prob: float = 0.0
        # Joint dropout: drops prompt AND image together.
        cond_drop_prob: float = 0.0

        gradient_checkpointing: bool = False

        # --- noise schedule / loss ---
        noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
        noise_offset: float = 0.0
        input_perturbation: float = 0.0
        snr_gamma: Optional[float] = 5.0
        prediction_type: Optional[str] = None
        shift_noise: bool = False
        shift_noise_mode: str = "interpolated"
        shift_noise_scale: float = 1.0

        # --- evaluation ---
        eval_seed: int = 0
        eval_num_inference_steps: int = 30
        eval_guidance_scale: float = 1.0
        eval_height: int = 512
        eval_width: int = 512

    cfg: Config

    def configure(self):
        """Build the pipeline, install the adapter, and set up trainability."""
        super().configure()

        # Optionally override the pipeline's VAE / UNet with separately
        # pretrained weights.
        pipeline_kwargs = {}
        if self.cfg.pretrained_vae_name_or_path is not None:
            pipeline_kwargs["vae"] = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_vae_name_or_path
            )
        if self.cfg.pretrained_unet_name_or_path is not None:
            pipeline_kwargs["unet"] = UNet2DConditionModel.from_pretrained(
                self.cfg.pretrained_unet_name_or_path
            )

        pipeline: IG2MVSDXLPipeline
        pipeline = IG2MVSDXLPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, **pipeline_kwargs
        )

        # The self-attention processor may be given as a dotted import path;
        # resolve it to the actual class before adapter init.
        init_adapter_kwargs = OmegaConf.to_container(self.cfg.init_adapter_kwargs)
        if "self_attn_processor" in init_adapter_kwargs:
            self_attn_processor = init_adapter_kwargs["self_attn_processor"]
            if self_attn_processor is not None and isinstance(self_attn_processor, str):
                self_attn_processor = find(self_attn_processor)
            init_adapter_kwargs["self_attn_processor"] = self_attn_processor
        pipeline.init_custom_adapter(**init_adapter_kwargs)

        if self.cfg.pretrained_adapter_name_or_path:
            pretrained_path = os.path.dirname(self.cfg.pretrained_adapter_name_or_path)
            adapter_name = os.path.basename(self.cfg.pretrained_adapter_name_or_path)
            pipeline.load_custom_adapter(pretrained_path, weight_name=adapter_name)

        # Training scheduler (DDPM), optionally wrapped with an SNR shift.
        noise_scheduler = DDPMScheduler.from_config(
            pipeline.scheduler.config, **self.cfg.noise_scheduler_kwargs
        )
        if self.cfg.shift_noise:
            noise_scheduler = ShiftSNRScheduler.from_scheduler(
                noise_scheduler,
                shift_mode=self.cfg.shift_noise_mode,
                shift_scale=self.cfg.shift_noise_scale,
                scheduler_class=DDPMScheduler,
            )
        pipeline.scheduler = noise_scheduler

        # Expose pipeline components as attributes; frozen encoders may be
        # cast to fp16 per config.
        self.pipeline: IG2MVSDXLPipeline = pipeline
        self.vae = self.pipeline.vae.to(
            dtype=torch.float16 if self.cfg.use_fp16_vae else torch.float32
        )
        self.tokenizer = self.pipeline.tokenizer
        self.tokenizer_2 = self.pipeline.tokenizer_2
        self.text_encoder = self.pipeline.text_encoder.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.text_encoder_2 = self.pipeline.text_encoder_2.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.feature_extractor = self.pipeline.feature_extractor

        self.cond_encoder = self.pipeline.cond_encoder
        self.unet = self.pipeline.unet
        # Keep the (possibly shifted) scheduler for training, and a plain
        # DDPM copy of its config for inference inside the pipeline.
        self.noise_scheduler = self.pipeline.scheduler
        self.inference_scheduler = DDPMScheduler.from_config(
            self.noise_scheduler.config
        )
        self.pipeline.scheduler = self.inference_scheduler
        if self.cfg.prediction_type is not None:
            self.noise_scheduler.register_to_config(
                prediction_type=self.cfg.prediction_type
            )

        # Freeze everything, then re-enable modules whose name contains any
        # of the configured substrings; empty list -> train the whole UNet.
        trainable_modules = self.cfg.trainable_modules
        if trainable_modules:
            self.unet.requires_grad_(False)
            for name, module in self.unet.named_modules():
                for trainable_module in trainable_modules:
                    if trainable_module in name:
                        module.requires_grad_(True)
        else:
            self.unet.requires_grad_(True)
        self.cond_encoder.requires_grad_(self.cfg.train_cond_encoder)

        # Encoders are always frozen.
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)

        if self.cfg.gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()

    def forward(
        self,
        noisy_latents: Tensor,
        conditioning_pixel_values: Tensor,
        timesteps: Tensor,
        ref_latents: Tensor,
        prompts: List[str],
        num_views: int,
        **kwargs,
    ) -> Dict[str, Any]:
        """Predict noise for multi-view latents conditioned on a reference image.

        Args:
            noisy_latents: Noised latents, batch laid out as
                (num_samples * num_views, ...) — views interleaved per sample.
            conditioning_pixel_values: Per-view conditioning images fed to the
                conditioning encoder.
            timesteps: One timestep per latent (already repeated per view).
            ref_latents: Reference-image latents, one per sample.
            prompts: One prompt per sample.
            num_views: Number of views per sample.
            **kwargs: Must carry the SDXL micro-conditioning sizes consumed by
                ``compute_embeddings``.

        Returns:
            Dict with ``noise_pred`` from the UNet.
        """
        bsz = noisy_latents.shape[0]
        b_samples = bsz // num_views
        num_batch_images = num_views

        # Per-sample CFG dropout masks; cond_drop drops prompt and image
        # jointly.
        prompt_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.prompt_drop_prob
        )
        image_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.image_drop_prob
        )
        cond_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device) < self.cfg.cond_drop_prob
        )
        prompt_drop_mask = prompt_drop_mask | cond_drop_mask
        image_drop_mask = image_drop_mask | cond_drop_mask

        # Text encoding runs frozen and outside autocast.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            additional_embeds = compute_embeddings(
                prompts,
                prompt_drop_mask,
                [self.text_encoder, self.text_encoder_2],
                [self.tokenizer, self.tokenizer_2],
                **kwargs,
            )

        # Reference pass: run the UNet once on the reference latents at t=0 to
        # cache its attention hidden states for reference-guided attention.
        with torch.no_grad():
            ref_timesteps = torch.zeros_like(timesteps[:b_samples])
            ref_hidden_states = {}
            self.unet(
                ref_latents,
                ref_timesteps,
                encoder_hidden_states=additional_embeds["prompt_embeds"],
                added_cond_kwargs={
                    "text_embeds": additional_embeds["text_embeds"],
                    "time_ids": additional_embeds["time_ids"],
                },
                cross_attention_kwargs={
                    "cache_hidden_states": ref_hidden_states,
                    "use_mv": False,
                    "use_ref": False,
                },
                return_dict=False,
            )
            # Zero out dropped references (image CFG dropout), then expand
            # from per-sample to per-view.
            for k, v in ref_hidden_states.items():
                v[image_drop_mask] = 0.0
                ref_hidden_states[k] = v.repeat_interleave(num_batch_images, dim=0)

        # Expand text conditioning from per-sample to per-view as well.
        for key, value in additional_embeds.items():
            kwargs[key] = value.repeat_interleave(num_batch_images, dim=0)

        conditioning_features = self.cond_encoder(conditioning_pixel_values)

        added_cond_kwargs = {
            "text_embeds": kwargs["text_embeds"],
            "time_ids": kwargs["time_ids"],
        }

        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=kwargs["prompt_embeds"],
            added_cond_kwargs=added_cond_kwargs,
            down_intrablock_additional_residuals=conditioning_features,
            cross_attention_kwargs={
                "ref_hidden_states": ref_hidden_states,
                "num_views": num_views,
            },
        ).sample

        return {"noise_pred": noise_pred}

    def training_step(self, batch, batch_idx):
        """One diffusion training step: encode, noise, predict, weighted MSE."""
        num_views = batch["num_views"]

        # Encode target views in slices to bound VAE memory use.
        vae_max_slice = 8
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            latents = []
            for i in range(0, batch["rgb"].shape[0], vae_max_slice):
                latents.append(
                    vae_encode(
                        self.vae,
                        batch["rgb"][i : i + vae_max_slice].to(self.vae.dtype) * 2 - 1,
                        sample=True,
                        apply_scale=True,
                    ).float()
                )
            latents = torch.cat(latents, dim=0)

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            ref_latents = vae_encode(
                self.vae,
                batch["reference_rgb"].to(self.vae.dtype) * 2 - 1,
                sample=True,
                apply_scale=True,
            ).float()

        bsz = latents.shape[0]
        b_samples = bsz // num_views

        noise = torch.randn_like(latents)
        # NOTE(review): defaults are floats (0.0), so these `is not None`
        # checks are always true; a zero offset/perturbation adds zero noise
        # but still consumes RNG draws. Kept as-is to preserve the RNG stream.
        if self.cfg.noise_offset is not None:
            # Per-channel offset noise (helps with very dark/bright images).
            noise += self.cfg.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )

        # Views with noise_mask=False are kept clean (timestep 0, no noise).
        noise_mask = (
            batch["noise_mask"]
            if "noise_mask" in batch
            else torch.ones((bsz,), dtype=torch.bool, device=latents.device)
        )
        # One timestep per sample, shared across that sample's views.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (b_samples,),
            device=latents.device,
            dtype=torch.long,
        )
        timesteps = timesteps.repeat_interleave(num_views)
        timesteps[~noise_mask] = 0

        if self.cfg.input_perturbation is not None:
            new_noise = noise + self.cfg.input_perturbation * torch.randn_like(noise)
            noisy_latents = self.noise_scheduler.add_noise(
                latents, new_noise, timesteps
            )
        else:
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        noisy_latents[~noise_mask] = latents[~noise_mask]

        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unsupported prediction type {self.noise_scheduler.config.prediction_type}"
            )

        model_pred = self(
            noisy_latents, batch["source_rgb"], timesteps, ref_latents, **batch
        )["noise_pred"]

        # Loss is computed only on the noised views.
        model_pred = model_pred[noise_mask]
        target = target[noise_mask]

        if self.cfg.snr_gamma is None:
            loss = F.mse_loss(model_pred, target, reduction="mean")
        else:
            # Min-SNR weighting (https://arxiv.org/abs/2303.09556):
            # weight = min(SNR, gamma) / SNR.
            # BUGFIX: weights must be computed on the *masked* timesteps so
            # they align with model_pred/target; previously the full
            # (unmasked) timesteps were used, which breaks whenever
            # noise_mask filters anything out.
            masked_timesteps = timesteps[noise_mask]
            snr = compute_snr(self.noise_scheduler, masked_timesteps)
            if self.noise_scheduler.config.prediction_type == "v_prediction":
                # Velocity objective requires SNR weights shifted by one.
                snr = snr + 1
            mse_loss_weights = (
                torch.stack(
                    [snr, self.cfg.snr_gamma * torch.ones_like(masked_timesteps)],
                    dim=1,
                ).min(dim=1)[0]
                / snr
            )

            loss = F.mse_loss(model_pred, target, reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        self.log("train/loss", loss, prog_bar=True)

        self.check_train(batch)

        return {"loss": loss}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        pass

    def get_input_visualizations(self, batch):
        """Tile source views, reference image, and target views for logging."""
        return [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]

    def get_output_visualizations(self, batch, outputs):
        """Tile inputs, ground truth, reference, and generated views."""
        images = [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    outputs, "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]
        return images

    def generate_images(self, batch, **kwargs):
        """Run the inference pipeline with a fixed eval seed; returns images (pt)."""
        return self.pipeline(
            prompt=batch["prompts"],
            control_image=batch["source_rgb"],
            num_images_per_prompt=batch["num_views"],
            generator=torch.Generator(device=self.device).manual_seed(
                self.cfg.eval_seed
            ),
            num_inference_steps=self.cfg.eval_num_inference_steps,
            guidance_scale=self.cfg.eval_guidance_scale,
            height=self.cfg.eval_height,
            width=self.cfg.eval_width,
            reference_image=batch["reference_rgb"],
            output_type="pt",
        ).images

    def on_save_checkpoint(self, checkpoint):
        """Save only the trainable adapter weights (rank 0 only)."""
        if self.global_rank == 0:
            self.pipeline.save_custom_adapter(
                os.path.dirname(self.get_save_dir()),
                "step1x-3d-ig2v.safetensors",
                safe_serialization=True,
                include_keys=self.cfg.trainable_modules,
            )

    def on_check_train(self, batch):
        self.save_image_grid(
            f"it{self.true_global_step}-train.jpg",
            self.get_input_visualizations(batch),
            name="train_step_input",
            step=self.true_global_step,
        )

    def validation_step(self, batch, batch_idx):
        out = self.generate_images(batch)

        # Only the first `check_val_limit_rank` ranks write images.
        if (
            self.cfg.check_val_limit_rank > 0
            and self.global_rank < self.cfg.check_val_limit_rank
        ):
            self.save_image_grid(
                f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg",
                self.get_output_visualizations(batch, out),
                name=f"validation_step_output_{self.global_rank}_{batch_idx}",
                step=self.true_global_step,
            )

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        out = self.generate_images(batch)

        self.save_image_grid(
            f"it{self.true_global_step}-test-{self.global_rank}_{batch_idx}.jpg",
            self.get_output_visualizations(batch, out),
            name=f"test_step_output_{self.global_rank}_{batch_idx}",
            step=self.true_global_step,
        )

    def on_test_end(self):
        pass
|
|