Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) 2023, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import logging | |
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| import tqdm | |
| from diffusers import ( | |
| AutoencoderKL, | |
| ControlNetModel, | |
| DDPMScheduler, | |
| DDIMScheduler, | |
| PNDMScheduler, | |
| UNet2DConditionModel, | |
| ) | |
| from torch import nn | |
| from transformers import CLIPTokenizer | |
| from transformers.activations import QuickGELUActivation as QuickGELU | |
| from lavis.common.registry import registry | |
| from lavis.common.utils import download_and_untar, is_url | |
| from lavis.models.base_model import BaseModel | |
| from lavis.models.blip2_models.blip2_qformer import Blip2Qformer | |
| from lavis.models.blip_diffusion_models.modeling_ctx_clip import CtxCLIPTextModel | |
| from lavis.models.blip_diffusion_models.utils import numpy_to_pil, prepare_cond_image | |
| from lavis.models.blip_diffusion_models.ptp_utils import ( | |
| LocalBlend, | |
| P2PCrossAttnProcessor, | |
| AttentionRefine, | |
| ) | |
class ProjLayer(nn.Module):
    """Residual MLP projection head.

    Pre-norm residual block: LayerNorm -> Dense1 -> QuickGELU -> Dense2 ->
    Dropout, with the un-normalized input added back at the end. Used to map
    BLIP multimodal embeddings into the CLIP text-embedding space.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
        super().__init__()

        # LayerNorm -> Dense1 -> Act -> Dense2 -> Drop -> residual add
        self.dense1 = nn.Linear(in_dim, hidden_dim)
        self.act_fn = QuickGELU()
        self.dense2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(drop_p)

        self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)

    def forward(self, x):
        # keep the raw input for the residual connection
        residual = x

        hidden = self.dense1(self.LayerNorm(x))
        hidden = self.dense2(self.act_fn(hidden))

        return self.dropout(hidden) + residual
class BlipDiffusion(BaseModel):
    """BLIP-Diffusion: subject-driven text-to-image generation.

    Combines a BLIP-2 Q-Former (subject representation), a context-aware CLIP
    text encoder (``CtxCLIPTextModel``), and Stable Diffusion components
    (VAE, UNet, schedulers), optionally with a ControlNet for
    structure-conditioned generation.
    """

    # model-type name -> default YAML config path (used by the LAVIS registry)
    PRETRAINED_MODEL_CONFIG_DICT = {
        "base": "configs/models/blip-diffusion/blip_diffusion_base.yaml",
        "canny": "configs/models/blip-diffusion/blip_diffusion_controlnet_canny.yaml",
        "depth": "configs/models/blip-diffusion/blip_diffusion_controlnet_depth.yaml",
        "hed": "configs/models/blip-diffusion/blip_diffusion_controlnet_hed.yaml",
    }
    def __init__(
        self,
        vit_model="clip_L",
        qformer_num_query_token=16,
        qformer_cross_attention_freq=1,
        qformer_pretrained_path=None,
        qformer_train=False,
        sd_pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        sd_train_text_encoder=False,
        controlnet_pretrained_model_name_or_path=None,
        vae_half_precision=False,
        proj_train=False,
    ):
        """Assemble the BLIP-2 Q-Former, projection layer and Stable Diffusion parts.

        Args:
            vit_model: vision backbone name passed to ``Blip2Qformer``.
            qformer_num_query_token: number of query tokens; also the number of
                context embeddings injected into the text encoder.
            qformer_cross_attention_freq: cross-attention frequency in the Q-Former.
            qformer_pretrained_path: optional Q-Former checkpoint; ``text_model``
                key prefixes are remapped to ``Qformer`` before loading.
            qformer_train: keep the Q-Former trainable (otherwise it is frozen).
            sd_pretrained_model_name_or_path: HF repo/dir holding the Stable
                Diffusion tokenizer, text encoder, VAE, UNet and scheduler.
            sd_train_text_encoder: keep the CLIP text encoder trainable.
            controlnet_pretrained_model_name_or_path: optional ControlNet weights;
                when given, ``self.controlnet`` is created and used at inference.
            vae_half_precision: cast the VAE to fp16.
            proj_train: keep the projection layer trainable.
        """
        super().__init__()

        self.num_query_token = qformer_num_query_token

        # BLIP-2: produces multimodal query embeddings for the subject image/text
        self.blip = Blip2Qformer(
            vit_model=vit_model,
            num_query_token=qformer_num_query_token,
            cross_attention_freq=qformer_cross_attention_freq,
        )
        if qformer_pretrained_path is not None:
            state_dict = torch.load(qformer_pretrained_path, map_location="cpu")[
                "model"
            ]
            # qformer keys: Qformer.bert.encoder.layer.1.attention.self.key.weight
            # ckpt keys: text_model.bert.encoder.layer.1.attention.self.key.weight
            for k in list(state_dict.keys()):
                if "text_model" in k:
                    state_dict[k.replace("text_model", "Qformer")] = state_dict.pop(k)

            msg = self.blip.load_state_dict(state_dict, strict=False)
            # only visual-encoder weights may be absent from the checkpoint
            assert all(["visual" in k for k in msg.missing_keys])
            assert len(msg.unexpected_keys) == 0

        self.qformer_train = qformer_train

        # projects 768-d BLIP embeddings into the 768-d CLIP text-embedding space
        self.proj_layer = ProjLayer(
            in_dim=768, out_dim=768, hidden_dim=3072, drop_p=0.1, eps=1e-12
        )
        self.proj_train = proj_train

        # stable diffusion components
        self.tokenizer = CLIPTokenizer.from_pretrained(
            sd_pretrained_model_name_or_path, subfolder="tokenizer"
        )
        self.text_encoder = CtxCLIPTextModel.from_pretrained(
            sd_pretrained_model_name_or_path, subfolder="text_encoder"
        )
        self.vae = AutoencoderKL.from_pretrained(
            sd_pretrained_model_name_or_path, subfolder="vae"
        )
        if vae_half_precision:
            self.vae.half()
        self.unet = UNet2DConditionModel.from_pretrained(
            sd_pretrained_model_name_or_path, subfolder="unet"
        )
        # self.unet.enable_xformers_memory_efficient_attention()

        # training-time (DDPM) noise schedule
        self.noise_scheduler = DDPMScheduler.from_config(
            sd_pretrained_model_name_or_path, subfolder="scheduler"
        )

        self.sd_train_text_encoder = sd_train_text_encoder

        if controlnet_pretrained_model_name_or_path is not None:
            self.controlnet = ControlNetModel.from_pretrained(
                controlnet_pretrained_model_name_or_path
            )

        self.freeze_modules()

        # cached mean subject embedding; filled by before_training() or a checkpoint
        self.ctx_embeddings_cache = nn.Parameter(
            torch.zeros(1, self.num_query_token, 768), requires_grad=False
        )
        self._use_embeddings_cache = False

        # inference-related
        # position where context embeddings are spliced into the token sequence
        # (presumably right after "<bos> a" — TODO confirm against CtxCLIPTextModel)
        self._CTX_BEGIN_POS = 2
| def freeze_modules(self): | |
| to_freeze = [self.vae] | |
| if not self.sd_train_text_encoder: | |
| to_freeze.append(self.text_encoder) | |
| if not self.qformer_train: | |
| to_freeze.append(self.blip) | |
| if not self.proj_train: | |
| to_freeze.append(self.proj_layer) | |
| for module in to_freeze: | |
| module.eval() | |
| module.train = self.disabled_train | |
| module.requires_grad_(False) | |
| def disabled_train(self, mode=True): | |
| """Overwrite model.train with this function to make sure train/eval mode | |
| does not change anymore.""" | |
| return self | |
| def pndm_scheduler(self): | |
| if not hasattr(self, "_pndm_scheduler"): | |
| self._pndm_scheduler = PNDMScheduler( | |
| beta_start=0.00085, | |
| beta_end=0.012, | |
| beta_schedule="scaled_linear", | |
| set_alpha_to_one=False, | |
| skip_prk_steps=True, | |
| ) | |
| return self._pndm_scheduler | |
| def ddim_scheduler(self): | |
| if not hasattr(self, "_ddim_scheduler"): | |
| self._ddim_scheduler = DDIMScheduler.from_config( | |
| "runwayml/stable-diffusion-v1-5", subfolder="scheduler" | |
| ) | |
| return self._ddim_scheduler | |
    def before_training(self, dataset, **kwargs):
        """Precompute and cache the subject context embeddings before finetuning.

        Encodes every example of the (single) train split with BLIP + the
        projection layer, averages the embeddings into one cached subject
        embedding, then moves BLIP and the projection layer to CPU since they
        are no longer needed during training.
        """
        assert len(dataset) == 1, "Only support single dataset for now."
        key = list(dataset.keys())[0]
        dataset = dataset[key]["train"]

        # collect all examples
        # [FIXME] this is not memory efficient. may OOM if the dataset is large.
        examples = [dataset[i] for i in range(dataset.len_without_repeat)]
        input_images = (
            torch.stack([example["inp_image"] for example in examples])
            .to(memory_format=torch.contiguous_format)
            .float()
        ).to(self.device)
        subject_text = [dataset.subject for _ in range(input_images.shape[0])]

        # calculate ctx embeddings and cache them
        ctx_embeddings = self.forward_ctx_embeddings(
            input_image=input_images, text_input=subject_text
        )
        # take mean of all ctx embeddings
        ctx_embeddings = ctx_embeddings.mean(dim=0, keepdim=True)
        # NOTE(review): requires_grad=True makes the cached embedding itself
        # trainable during finetuning (unlike the frozen cache in __init__) —
        # presumably intentional; confirm.
        self.ctx_embeddings_cache = nn.Parameter(ctx_embeddings, requires_grad=True)
        self._use_embeddings_cache = True

        # free up CUDA memory
        self.blip.to("cpu")
        self.proj_layer.to("cpu")

        torch.cuda.empty_cache()
    def forward(self, samples):
        """One denoising-training step.

        Args:
            samples: dict with "tgt_image" (target image tensor), "inp_image"
                (subject reference image), "subject_text" and "caption".

        Returns:
            dict with "loss": MSE between predicted and true noise.
        """
        # encode target image into scaled latents (0.18215 = SD v1 scaling factor)
        # NOTE(review): input is always cast to half here — assumes the VAE runs
        # in fp16 (vae_half_precision); confirm for fp32 training.
        latents = self.vae.encode(samples["tgt_image"].half()).latent_dist.sample()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=latents.device,
        )
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # subject context embeddings to be injected into the text encoder
        ctx_embeddings = self.forward_ctx_embeddings(
            input_image=samples["inp_image"], text_input=samples["subject_text"]
        )

        # Get the text embedding for conditioning
        input_ids = self.tokenizer(
            samples["caption"],
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(self.device)
        encoder_hidden_states = self.text_encoder(
            input_ids=input_ids,
            ctx_embeddings=ctx_embeddings,
            ctx_begin_pos=[self._CTX_BEGIN_POS] * input_ids.shape[0],
        )[0]

        # Predict the noise residual
        noise_pred = self.unet(
            noisy_latents.float(), timesteps, encoder_hidden_states
        ).sample

        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

        return {"loss": loss}
| def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): | |
| rv = [] | |
| for prompt, tgt_subject in zip(prompts, tgt_subjects): | |
| prompt = f"a {tgt_subject} {prompt.strip()}" | |
| # a trick to amplify the prompt | |
| rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) | |
| return rv | |
| def _build_prompts_edit(self, cond_subject, tgt_subject, prompt): | |
| placeholder = " ".join(["sks"] * self.num_query_token) | |
| src_prompt = f"a {cond_subject} {prompt}" | |
| tgt_prompt = f"a {placeholder} {tgt_subject} {prompt}" | |
| return [src_prompt, tgt_prompt] | |
    def _predict_noise(
        self,
        t,
        latent_model_input,
        text_embeddings,
        width=512,
        height=512,
        cond_image=None,
    ):
        """Run the UNet (optionally with ControlNet residuals) for one timestep.

        When a ControlNet is attached, ``cond_image`` is preprocessed and the
        ControlNet's down/mid block residuals are passed into the UNet;
        otherwise the UNet runs without additional residuals.
        """
        if hasattr(self, "controlnet"):
            cond_image = prepare_cond_image(
                cond_image, width, height, batch_size=1, device=self.device
            )

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=cond_image,
                # conditioning_scale=controlnet_condition_scale,
                return_dict=False,
            )
        else:
            down_block_res_samples, mid_block_res_sample = None, None

        noise_pred = self.unet(
            latent_model_input,
            timestep=t,
            encoder_hidden_states=text_embeddings,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        )["sample"]

        return noise_pred
| def _init_latent(self, latent, height, width, generator, batch_size): | |
| if latent is None: | |
| latent = torch.randn( | |
| (1, self.unet.in_channels, height // 8, width // 8), | |
| generator=generator, | |
| device=generator.device, | |
| ) | |
| latent = latent.expand( | |
| batch_size, | |
| self.unet.in_channels, | |
| height // 8, | |
| width // 8, | |
| ) | |
| return latent.to(self.device) | |
| def _forward_prompt_embeddings(self, input_image, src_subject, prompt): | |
| # 1. extract BLIP query features and proj to text space -> (bs, 32, 768) | |
| query_embeds = self.forward_ctx_embeddings(input_image, src_subject) | |
| # 2. embeddings for prompt, with query_embeds as context | |
| tokenized_prompt = self._tokenize_text(prompt).to(self.device) | |
| text_embeddings = self.text_encoder( | |
| input_ids=tokenized_prompt.input_ids, | |
| ctx_embeddings=query_embeds, | |
| ctx_begin_pos=[self._CTX_BEGIN_POS], | |
| )[0] | |
| return text_embeddings | |
| def get_image_latents(self, image, sample=True, rng_generator=None): | |
| assert isinstance(image, torch.Tensor) | |
| encoding_dist = self.vae.encode(image).latent_dist | |
| if sample: | |
| encoding = encoding_dist.sample(generator=rng_generator) | |
| else: | |
| encoding = encoding_dist.mode() | |
| latents = encoding * 0.18215 | |
| return latents | |
| def _inversion_transform(self, image, target_size=512): | |
| from torchvision import transforms | |
| tform = transforms.Compose( | |
| [ | |
| transforms.Resize(target_size), | |
| transforms.CenterCrop(target_size), | |
| transforms.ToTensor(), | |
| ] | |
| ) | |
| image = tform(image).unsqueeze(0).to(self.device) | |
| return 2.0 * image - 1.0 | |
    def edit(
        self,
        samples,
        lb_threshold=0.3,
        guidance_scale=7.5,
        height=512,
        width=512,
        seed=42,
        num_inference_steps=50,
        num_inversion_steps=50,
        neg_prompt="",
    ):
        """Edit a real image via DDIM inversion + prompt-to-prompt regeneration.

        Pipeline: preprocess ``samples["raw_image"]`` -> encode to VAE latents
        -> DDIM-invert to noise latents (conditioned on the source prompt) ->
        regenerate with ``generate_then_edit`` so the target subject is blended
        in locally. ``lb_threshold`` controls the LocalBlend mask.

        Returns:
            the edited image(s) produced by ``generate_then_edit``.
        """
        raw_image = samples["raw_image"]
        raw_image = self._inversion_transform(raw_image)
        latents = self.get_image_latents(raw_image, rng_generator=None)

        # inversion runs with guidance_scale=1.0 (no classifier-free guidance)
        inv_latents = self._ddim_inverse(
            samples=samples,
            latents=latents,
            seed=seed,
            guidance_scale=1.0,
            height=height,
            width=width,
            num_inference_steps=num_inversion_steps,
        )
        recon_image = self.generate_then_edit(
            samples=samples,
            latents=inv_latents,
            seed=seed,
            neg_prompt=neg_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            use_inversion=True,
            lb_threshold=lb_threshold,
        )

        return recon_image
| def _ddim_inverse( | |
| self, | |
| samples, | |
| latents, | |
| guidance_scale=1.0, | |
| height=512, | |
| width=512, | |
| seed=42, | |
| num_inference_steps=50, | |
| ): | |
| src_subject = samples["src_subject"] # source subject category | |
| prompt = samples["prompt"] | |
| prompt = self._build_prompt( | |
| prompts=prompt, | |
| tgt_subjects=src_subject, | |
| prompt_strength=1.0, | |
| prompt_reps=1, | |
| ) | |
| tokenized_prompt = self._tokenize_text(prompt, with_query=False).to(self.device) | |
| text_embeddings = self.text_encoder( | |
| input_ids=tokenized_prompt.input_ids, | |
| ctx_embeddings=None, | |
| )[0] | |
| if seed is not None: | |
| generator = torch.Generator(device=self.device) | |
| generator = generator.manual_seed(seed) | |
| latents = self._init_latent(latents, height, width, generator, batch_size=1) | |
| scheduler = self.ddim_scheduler | |
| # set timesteps | |
| extra_set_kwargs = {} | |
| scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) | |
| iterator = tqdm.tqdm(reversed(scheduler.timesteps)) | |
| for i, t in enumerate(iterator): | |
| latents = self._noise_latent_step( | |
| latents=latents, | |
| t=t, | |
| text_embeddings=text_embeddings, | |
| height=height, | |
| width=width, | |
| guidance_scale=guidance_scale, | |
| ) | |
| return latents | |
| def generate( | |
| self, | |
| samples, | |
| latents=None, | |
| guidance_scale=7.5, | |
| height=512, | |
| width=512, | |
| seed=42, | |
| num_inference_steps=50, | |
| neg_prompt="", | |
| controller=None, | |
| prompt_strength=1.0, | |
| prompt_reps=20, | |
| use_ddim=False, | |
| ): | |
| if controller is not None: | |
| self._register_attention_refine(controller) | |
| cond_image = samples["cond_images"] # reference image | |
| cond_subject = samples["cond_subject"] # source subject category | |
| tgt_subject = samples["tgt_subject"] # target subject category | |
| prompt = samples["prompt"] | |
| cldm_cond_image = samples.get("cldm_cond_image", None) # conditional image | |
| prompt = self._build_prompt( | |
| prompts=prompt, | |
| tgt_subjects=tgt_subject, | |
| prompt_strength=prompt_strength, | |
| prompt_reps=prompt_reps, | |
| ) | |
| text_embeddings = self._forward_prompt_embeddings( | |
| cond_image, cond_subject, prompt | |
| ) | |
| # 3. unconditional embedding | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| if do_classifier_free_guidance: | |
| max_length = self.text_encoder.text_model.config.max_position_embeddings | |
| uncond_input = self.tokenizer( | |
| [neg_prompt], | |
| padding="max_length", | |
| max_length=max_length, | |
| return_tensors="pt", | |
| ) | |
| uncond_embeddings = self.text_encoder( | |
| input_ids=uncond_input.input_ids.to(self.device), | |
| ctx_embeddings=None, | |
| )[0] | |
| # For classifier free guidance, we need to do two forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing two forward passes | |
| text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | |
| if seed is not None: | |
| generator = torch.Generator(device=self.device) | |
| generator = generator.manual_seed(seed) | |
| latents = self._init_latent(latents, height, width, generator, batch_size=1) | |
| scheduler = self.pndm_scheduler if not use_ddim else self.ddim_scheduler | |
| # set timesteps | |
| extra_set_kwargs = {} | |
| scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) | |
| iterator = tqdm.tqdm(scheduler.timesteps) | |
| for i, t in enumerate(iterator): | |
| latents = self._denoise_latent_step( | |
| latents=latents, | |
| t=t, | |
| text_embeddings=text_embeddings, | |
| cond_image=cldm_cond_image, | |
| height=height, | |
| width=width, | |
| guidance_scale=guidance_scale, | |
| use_inversion=use_ddim, | |
| ) | |
| image = self._latent_to_image(latents) | |
| return image | |
    def _register_attention_refine(
        self,
        src_subject,
        prompts,
        num_inference_steps,
        cross_replace_steps=0.8,
        self_replace_steps=0.4,
        threshold=0.3,
    ):
        """Create and install a prompt-to-prompt ``AttentionRefine`` controller.

        A ``LocalBlend`` mask built from *src_subject* restricts edits to the
        subject region; ``cross_replace_steps`` / ``self_replace_steps`` set
        for which fraction of sampling steps the cross-/self-attention maps
        are replaced.

        Returns:
            the controller, so callers can use its step/reset callbacks.
        """
        device, tokenizer = self.device, self.tokenizer

        lb = LocalBlend(
            prompts=prompts,
            words=(src_subject,),
            device=device,
            tokenizer=tokenizer,
            threshold=threshold,
        )

        controller = AttentionRefine(
            prompts,
            num_inference_steps,
            cross_replace_steps=cross_replace_steps,
            self_replace_steps=self_replace_steps,
            tokenizer=tokenizer,
            device=device,
            local_blend=lb,
        )
        self._register_attention_control(controller)

        return controller
    def _register_attention_control(self, controller):
        """Replace every UNet attention processor with a P2P processor bound to
        *controller*, and record the attention-layer count on the controller.

        Processor names look like ``down_blocks.1....attn1.processor``; the
        prefix determines ``place_in_unet`` ("down"/"mid"/"up").
        """
        attn_procs = {}
        cross_att_count = 0
        for name in self.unet.attn_processors.keys():
            # attn1 is self-attention (no cross-attention dim)
            # NOTE(review): cross_attention_dim and hidden_size are computed but
            # never passed to P2PCrossAttnProcessor — presumably leftovers.
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else self.unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = self.unet.config.block_out_channels[-1]
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.unet.config.block_out_channels))[
                    block_id
                ]
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.unet.config.block_out_channels[block_id]
                place_in_unet = "down"
            else:
                continue

            cross_att_count += 1
            attn_procs[name] = P2PCrossAttnProcessor(
                controller=controller, place_in_unet=place_in_unet
            )

        self.unet.set_attn_processor(attn_procs)
        if controller is not None:
            controller.num_att_layers = cross_att_count
| def generate_then_edit( | |
| self, | |
| samples, | |
| cross_replace_steps=0.8, | |
| self_replace_steps=0.4, | |
| guidance_scale=7.5, | |
| height=512, | |
| width=512, | |
| latents=None, | |
| seed=42, | |
| num_inference_steps=250, | |
| neg_prompt="", | |
| use_inversion=False, | |
| lb_threshold=0.3, | |
| ): | |
| cond_image = samples["cond_images"] # reference image | |
| cond_subject = samples["cond_subject"] # source subject category | |
| src_subject = samples["src_subject"] | |
| tgt_subject = samples["tgt_subject"] # target subject category | |
| prompt = samples["prompt"] | |
| assert len(prompt) == 1, "Do not support multiple prompts for now" | |
| prompt = self._build_prompts_edit(src_subject, tgt_subject, prompt[0]) | |
| print(prompt) | |
| controller = self._register_attention_refine( | |
| src_subject=src_subject, | |
| prompts=prompt, | |
| num_inference_steps=num_inference_steps, | |
| cross_replace_steps=cross_replace_steps, | |
| self_replace_steps=self_replace_steps, | |
| threshold=lb_threshold, | |
| ) | |
| query_embeds = self.forward_ctx_embeddings(cond_image, cond_subject) | |
| tokenized_prompt_bef = self._tokenize_text(prompt[:1], with_query=False).to( | |
| self.device | |
| ) | |
| tokenized_prompt_aft = self._tokenize_text(prompt[1:], with_query=True).to( | |
| self.device | |
| ) | |
| text_embeddings_bef = self.text_encoder( | |
| input_ids=tokenized_prompt_bef.input_ids, | |
| )[0] | |
| text_embeddings_aft = self.text_encoder( | |
| input_ids=tokenized_prompt_aft.input_ids, | |
| ctx_embeddings=query_embeds, | |
| ctx_begin_pos=[self._CTX_BEGIN_POS], | |
| )[0] | |
| text_embeddings = torch.cat([text_embeddings_bef, text_embeddings_aft], dim=0) | |
| # 3. unconditional embedding | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| # [TODO] add support for batched input | |
| batch_size = 2 | |
| if do_classifier_free_guidance: | |
| max_length = self.text_encoder.text_model.config.max_position_embeddings | |
| uncond_input = self.tokenizer( | |
| [neg_prompt], | |
| padding="max_length", | |
| max_length=max_length, | |
| return_tensors="pt", | |
| ) | |
| # FIXME use context embedding for uncond_input or not? | |
| uncond_embeddings = self.text_encoder( | |
| input_ids=uncond_input.input_ids.to(self.device), | |
| ctx_embeddings=None, | |
| )[0] | |
| # repeat the uncond embedding to match the number of prompts | |
| uncond_embeddings = uncond_embeddings.expand(batch_size, -1, -1) | |
| # For classifier free guidance, we need to do two forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing two forward passes | |
| text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | |
| if seed is not None: | |
| generator = torch.Generator(device=self.device) | |
| generator = generator.manual_seed(seed) | |
| latents = self._init_latent(latents, height, width, generator, batch_size) | |
| scheduler = self.pndm_scheduler if not use_inversion else self.ddim_scheduler | |
| # set timesteps | |
| scheduler.set_timesteps(num_inference_steps) | |
| iterator = tqdm.tqdm(scheduler.timesteps) | |
| for i, t in enumerate(iterator): | |
| latents = self._denoise_latent_step( | |
| latents=latents, | |
| t=t, | |
| text_embeddings=text_embeddings, | |
| height=height, | |
| width=width, | |
| guidance_scale=guidance_scale, | |
| use_inversion=use_inversion, | |
| ) | |
| latents = controller.step_callback(latents) | |
| image = self._latent_to_image(latents) | |
| controller.reset() | |
| return image | |
| def _latent_to_image(self, latents): | |
| latents = 1 / 0.18215 * latents | |
| image = self.vae.decode(latents).sample | |
| image = (image / 2 + 0.5).clamp(0, 1) | |
| image = image.cpu().permute(0, 2, 3, 1).numpy() | |
| image = numpy_to_pil(image) | |
| return image | |
    def _noise_latent_step(
        self,
        latents,
        t,
        text_embeddings,
        guidance_scale,
        height,
        width,
    ):
        """One DDIM *inversion* step: move latents toward noise at timestep t.

        Reuses the reverse-DDIM update formula but swaps alpha_t and
        alpha_{t-1} before applying it, turning the image->noise direction
        into the update actually performed.
        """
        def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt):
            """from noise to image"""
            return (
                alpha_tm1**0.5
                * (
                    (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
                    + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
                )
                + x_t
            )

        do_classifier_free_guidance = guidance_scale > 1.0
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        # predict the noise residual
        # NOTE(review): unlike _denoise_latent_step, no guidance combination is
        # applied to noise_pred afterwards — inversion is expected to run with
        # guidance_scale == 1.0 (see edit()); confirm before using CFG here.
        noise_pred = self._predict_noise(
            t=t,
            latent_model_input=latent_model_input,
            text_embeddings=text_embeddings,
            width=width,
            height=height,
        )
        scheduler = self.ddim_scheduler

        prev_timestep = (
            t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
        )
        alpha_prod_t = scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            scheduler.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else scheduler.final_alpha_cumprod
        )
        # swap the two alphas so backward_ddim steps x_t toward x_{t+1} (inversion)
        alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
        latents = backward_ddim(
            x_t=latents,
            alpha_t=alpha_prod_t,
            alpha_tm1=alpha_prod_t_prev,
            eps_xt=noise_pred,
        )
        return latents
    def _denoise_latent_step(
        self,
        latents,
        t,
        text_embeddings,
        guidance_scale,
        height,
        width,
        cond_image=None,
        use_inversion=False,
    ):
        """One scheduler step x_t -> x_{t-1}, with classifier-free guidance.

        With ``use_inversion`` (the generate-then-edit path), the latent batch
        is [src, tgt]: the source sample keeps its un-guided conditional noise
        prediction (row 2 of the doubled [uncond_src, uncond_tgt, cond_src,
        cond_tgt] batch) so it reconstructs the inverted image faithfully,
        while the target sample uses the guided prediction.
        """
        if use_inversion:
            noise_placeholder = []

        # expand the latents if we are doing classifier free guidance
        do_classifier_free_guidance = guidance_scale > 1.0
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )

        # predict the noise residual
        noise_pred = self._predict_noise(
            t=t,
            latent_model_input=latent_model_input,
            text_embeddings=text_embeddings,
            width=width,
            height=height,
            cond_image=cond_image,
        )

        if use_inversion:
            # conditional (un-guided) prediction for the source sample
            noise_placeholder.append(noise_pred[2].unsqueeze(0))

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        if use_inversion:
            # guided prediction for the target sample
            noise_placeholder.append(noise_pred[-1].unsqueeze(0))
            noise_pred = torch.cat(noise_placeholder)

        # compute the previous noisy sample x_t -> x_t-1
        scheduler = self.ddim_scheduler if use_inversion else self.pndm_scheduler
        latents = scheduler.step(
            noise_pred,
            t,
            latents,
        )["prev_sample"]

        return latents
| def _tokenize_text(self, text_input, with_query=True): | |
| max_len = self.text_encoder.text_model.config.max_position_embeddings | |
| if with_query: | |
| max_len -= self.num_query_token | |
| tokenized_text = self.tokenizer( | |
| text_input, | |
| padding="max_length", | |
| truncation=True, | |
| max_length=max_len, | |
| return_tensors="pt", | |
| ) | |
| return tokenized_text | |
| def forward_ctx_embeddings(self, input_image, text_input, ratio=None): | |
| def compute_ctx_embeddings(input_image, text_input): | |
| # blip_embeddings = self.blip(image=input_image, text=text_input) | |
| blip_embeddings = self.blip.extract_features( | |
| {"image": input_image, "text_input": text_input}, mode="multimodal" | |
| ).multimodal_embeds | |
| ctx_embeddings = self.proj_layer(blip_embeddings) | |
| return ctx_embeddings | |
| if isinstance(text_input, str): | |
| text_input = [text_input] | |
| if self._use_embeddings_cache: | |
| # expand to batch size | |
| ctx_embeddings = self.ctx_embeddings_cache.expand(len(text_input), -1, -1) | |
| else: | |
| if isinstance(text_input[0], str): | |
| text_input, input_image = [text_input], [input_image] | |
| all_ctx_embeddings = [] | |
| for inp_image, inp_text in zip(input_image, text_input): | |
| ctx_embeddings = compute_ctx_embeddings(inp_image, inp_text) | |
| all_ctx_embeddings.append(ctx_embeddings) | |
| if ratio is not None: | |
| assert len(ratio) == len(all_ctx_embeddings) | |
| assert sum(ratio) == 1 | |
| else: | |
| ratio = [1 / len(all_ctx_embeddings)] * len(all_ctx_embeddings) | |
| ctx_embeddings = torch.zeros_like(all_ctx_embeddings[0]) | |
| for ratio, ctx_embeddings_ in zip(ratio, all_ctx_embeddings): | |
| ctx_embeddings += ratio * ctx_embeddings_ | |
| return ctx_embeddings | |
| def from_config(cls, cfg): | |
| vit_model = cfg.get("vit_model", "clip_L") | |
| qformer_cross_attention_freq = cfg.get("qformer_cross_attention_freq", 1) | |
| qformer_num_query_token = cfg.get("qformer_num_query_token", 16) | |
| qformer_train = cfg.get("qformer_train", False) | |
| sd_train_text_encoder = cfg.get("sd_train_text_encoder", False) | |
| sd_pretrained_model_name_or_path = cfg.get( | |
| "sd_pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5" | |
| ) | |
| controlnet_pretrained_model_name_or_path = cfg.get( | |
| "controlnet_pretrained_model_name_or_path", None | |
| ) | |
| vae_half_precision = cfg.get("vae_half_precision", False) | |
| model = cls( | |
| vit_model=vit_model, | |
| qformer_cross_attention_freq=qformer_cross_attention_freq, | |
| qformer_num_query_token=qformer_num_query_token, | |
| qformer_train=qformer_train, | |
| sd_train_text_encoder=sd_train_text_encoder, | |
| sd_pretrained_model_name_or_path=sd_pretrained_model_name_or_path, | |
| controlnet_pretrained_model_name_or_path=controlnet_pretrained_model_name_or_path, | |
| vae_half_precision=vae_half_precision, | |
| ) | |
| model.load_checkpoint_from_config(cfg) | |
| return model | |
    def load_checkpoint_from_dir(self, checkpoint_dir_or_url):
        """Load component weights from a checkpoint directory (or URL to a tarball).

        Each component (proj_layer, blip, unet, vae, text_encoder) is loaded
        best-effort from its own file; missing files are skipped with a log
        message. If a cached subject embedding is present it is loaded onto the
        model device and enabled.
        """
        # if checkpoint_dir is a url, download it and untar it
        if is_url(checkpoint_dir_or_url):
            checkpoint_dir_or_url = download_and_untar(checkpoint_dir_or_url)
        logging.info(f"Loading pretrained model from {checkpoint_dir_or_url}")

        def load_state_dict(module, filename):
            # best-effort per-component load; an absent file is not an error
            try:
                state_dict = torch.load(
                    os.path.join(checkpoint_dir_or_url, filename), map_location="cpu"
                )

                msg = module.load_state_dict(state_dict, strict=False)
            except FileNotFoundError:
                logging.info("File not found, skip loading: {}".format(filename))

        load_state_dict(self.proj_layer, "proj_layer/proj_weight.pt")
        load_state_dict(self.blip, "blip_model/blip_weight.pt")
        load_state_dict(self.unet, "unet/diffusion_pytorch_model.bin")
        load_state_dict(self.vae, "vae/diffusion_pytorch_model.bin")
        load_state_dict(self.text_encoder, "text_encoder/pytorch_model.bin")

        try:
            self.ctx_embeddings_cache.data = torch.load(
                os.path.join(
                    checkpoint_dir_or_url, "ctx_embeddings_cache/ctx_embeddings_cache.pt"
                ),
                map_location=self.device,
            )
            self._use_embeddings_cache = True

            print("Loaded ctx_embeddings_cache from {}".format(checkpoint_dir_or_url))
        except FileNotFoundError:
            self._use_embeddings_cache = False
            print("No ctx_embeddings_cache found in {}".format(checkpoint_dir_or_url))
| def load_from_pretrained(self, url_or_filename): | |
| checkpoint_dir = url_or_filename | |
| self.load_checkpoint_from_dir(checkpoint_dir) | |
| def load_checkpoint(self, url_or_filename): | |
| """ | |
| Used to load finetuned models. | |
| """ | |
| super().load_checkpoint(url_or_filename) | |
| print("loading fine-tuned model from {}".format(url_or_filename)) | |
| self._use_embeddings_cache = True | |