| | from contextlib import nullcontext |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import torch |
| | from einops import rearrange |
| | from PIL import Image |
| | from tqdm.auto import tqdm |
| |
|
| | from diffusers import DiffusionPipeline |
| | from diffusers.pipelines.pipeline_utils import ImagePipelineOutput |
| |
|
| | from .constants import SUPPORTED_IMAGE_SIZES |
| |
|
| |
|
# Accepted prompt input: either one prompt string or a list of prompt strings.
PromptType = Union[str, List[str]]
| |
|
| |
|
| | def _get_pkv_seq_len(past_key_values) -> int: |
| | """Get cached sequence length from past_key_values (supports tuple and DynamicCache).""" |
| | if hasattr(past_key_values, "get_seq_length"): |
| | return past_key_values.get_seq_length() |
| | return past_key_values[0][0].shape[2] |
| |
|
| |
|
class BitDanceDiffusionPipeline(DiffusionPipeline):
    """Text-to-image pipeline that decodes image tokens in groups of ``parallel_num``.

    The pipeline wires together five registered modules:

    * ``tokenizer`` / ``text_encoder``: the prompt is tokenized, embedded with
      ``text_encoder.model.embed_tokens`` and run through ``text_encoder.model``
      with a key/value cache;
    * ``diffusion_head``: its ``sample`` method turns hidden states into latent
      predictions for one group of tokens;
    * ``projector``: maps the sign-quantized latents back to embeddings that are
      fed to the text encoder for the next group;
    * ``autoencoder``: its ``decode`` method turns the full latent grid into
      images.
    """

    model_cpu_offload_seq = "text_encoder->projector->diffusion_head->autoencoder"

    def __init__(
        self,
        tokenizer,
        text_encoder,
        autoencoder,
        diffusion_head,
        projector,
        supported_image_sizes: Optional[List[List[int]]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Register the sub-modules and precompute the positional table.

        Args:
            tokenizer: Tokenizer providing ``encode`` and ``convert_tokens_to_ids``.
            text_encoder: Model exposing ``config.hidden_size`` and a ``model``
                attribute with ``embed_tokens``.
            autoencoder: Module exposing ``patch_size`` and ``decode``.
            diffusion_head: Module exposing ``config.parallel_num`` and ``sample``.
            projector: Callable mapping quantized latents to embeddings.
            supported_image_sizes: Allowed ``[height, width]`` pairs; falls back
                to ``SUPPORTED_IMAGE_SIZES`` when ``None`` or empty.
            dtype: Accepted for interface compatibility; not used by this
                constructor.

        Raises:
            ValueError: If ``parallel_num`` is not a perfect square.
        """
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            autoencoder=autoencoder,
            diffusion_head=diffusion_head,
            projector=projector,
        )

        image_sizes = supported_image_sizes or SUPPORTED_IMAGE_SIZES
        self.register_to_config(supported_image_sizes=[list(size) for size in image_sizes])

        self.hidden_size = self.text_encoder.config.hidden_size
        self.vae_patch_size = self.autoencoder.patch_size
        self.parallel_num = int(self.diffusion_head.config.parallel_num)
        # Each parallel group is a ps x ps square of latent positions.
        self.ps = int(self.parallel_num**0.5)
        if self.ps * self.ps != self.parallel_num:
            raise ValueError(
                f"parallel_num must be a perfect square (got {self.parallel_num})."
            )

        self._build_pos_embed()

    @property
    def supported_image_sizes(self) -> List[List[int]]:
        """Return a fresh copy of the configured ``[height, width]`` pairs."""
        return [list(size) for size in self.config.supported_image_sizes]

    def _execution_device_fallback(self) -> torch.device:
        """Return ``_execution_device`` when set, else the text encoder's device."""
        if getattr(self, "_execution_device", None) is not None:
            return self._execution_device
        return next(self.text_encoder.parameters()).device

    def _build_pos_embed(self) -> None:
        """Precompute the 1D sin/cos table for the largest supported side length.

        The table has one row per latent position along a side
        (``max_resolution // vae_patch_size`` rows) and ``hidden_size // 2``
        columns; it is stored on ``self.pos_embed_1d``.
        """
        max_resolution = max(max(size) for size in self.supported_image_sizes)
        max_len = max_resolution // self.vae_patch_size
        self.pos_embed_1d = self._get_1d_sincos_pos_embed(self.hidden_size // 2, max_len)

    @staticmethod
    def _get_1d_sincos_pos_embed(dim: int, max_len: int, pe_interpolation: float = 1.0) -> torch.Tensor:
        """Build a ``(max_len, dim)`` float32 sinusoidal position table.

        The first ``dim // 2`` columns hold sines and the last ``dim // 2``
        columns hold cosines of ``position / pe_interpolation`` scaled by
        frequencies ``1 / 10000 ** (k / (dim / 2))``.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        omega = torch.arange(dim // 2, dtype=torch.float32)
        omega /= dim / 2.0
        omega = 1.0 / 10000**omega
        pos = torch.arange(max_len, dtype=torch.float32) / pe_interpolation
        out = torch.einsum("m,d->md", pos, omega)
        emb_sin = torch.sin(out)
        emb_cos = torch.cos(out)
        return torch.cat([emb_sin, emb_cos], dim=1)

    def _get_2d_embed(self, h: int, w: int, ps: int = 1) -> torch.Tensor:
        """Return an ``(h * w, hidden_size)`` positional embedding for an ``h x w`` grid.

        Each grid cell concatenates the column embedding (first half) with the
        row embedding (second half). Cells are then flattened so that every
        ``ps x ps`` patch occupies consecutive rows, patches being visited in
        row-major order.
        """
        half = self.hidden_size // 2
        emb_v = self.pos_embed_1d[:h]
        emb_h = self.pos_embed_1d[:w]
        grid_v = emb_v.view(h, 1, half).repeat(1, w, 1)
        grid_h = emb_h.view(1, w, half).repeat(h, 1, 1)
        pos_embed = torch.cat([grid_h, grid_v], dim=-1)
        return rearrange(pos_embed, "(h p1) (w p2) c -> (h w p1 p2) c", p1=ps, p2=ps)

    def _encode_prompt_to_embeds(
        self,
        prompt: str,
        image_size: Tuple[int, int],
        num_images_per_prompt: int,
        guidance_scale: float,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Embed the chat-formatted prompt followed by the image prefix tokens.

        The image prefix is ``<|vision_start|>``, the two resolution tokens
        ``<|res_{side // vae_patch_size}|>`` (height, then width) and the query
        tokens ``<|query_1|>`` .. ``<|query_{parallel_num - 1}|>``.

        Args:
            prompt: User prompt text.
            image_size: ``(height, width)`` in pixels.
            num_images_per_prompt: Batch size the embeddings are repeated to.
            guidance_scale: Unconditional embeddings are only built when this
                is greater than ``1.0``.

        Returns:
            ``(cond, uncond, image_prefix)`` where ``cond`` (and ``uncond`` when
            built, else ``None``) has a leading batch dimension of
            ``num_images_per_prompt`` and ``image_prefix`` is the unbatched
            embedding of the image prefix tokens.
        """
        device = self._execution_device_fallback()
        model = self.text_encoder.model
        tokenizer = self.tokenizer
        use_cfg = guidance_scale > 1.0

        cond_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        uncond_prompt = "<|im_start|>assistant\n"

        def embed_text(text: str) -> torch.Tensor:
            ids = torch.tensor(tokenizer.encode(text), device=device, dtype=torch.long)
            return model.embed_tokens(ids)

        cond_emb = embed_text(cond_prompt)
        uncond_emb = embed_text(uncond_prompt) if use_cfg else None

        image_h, image_w = image_size
        prefix_tokens = [
            "<|vision_start|>",
            f"<|res_{image_h // self.vae_patch_size}|>",
            f"<|res_{image_w // self.vae_patch_size}|>",
        ]
        prefix_tokens.extend(f"<|query_{i}|>" for i in range(1, self.parallel_num))
        prefix_ids = [tokenizer.convert_tokens_to_ids(token) for token in prefix_tokens]
        # One embedding lookup for the whole prefix instead of a lookup plus a
        # torch.cat per query token; the resulting rows are the same.
        img_start_emb = model.embed_tokens(torch.tensor(prefix_ids, device=device, dtype=torch.long))

        def batch(text_emb: torch.Tensor) -> torch.Tensor:
            joined = torch.cat([text_emb, img_start_emb], dim=0)
            return joined.unsqueeze(0).repeat(num_images_per_prompt, 1, 1)

        input_embeds_cond = batch(cond_emb)
        input_embeds_uncond = batch(uncond_emb) if uncond_emb is not None else None
        return input_embeds_cond, input_embeds_uncond, img_start_emb

    def _decode_tokens_to_image(self, image_latents: torch.Tensor, image_size: Tuple[int, int], ps: int = 1) -> torch.Tensor:
        """Un-flatten patch-ordered latents to ``(b, c, h, w)`` and decode them.

        ``image_size`` is the latent grid size ``(h, w)``; the rearrangement is
        the inverse of the patch ordering produced by ``_get_2d_embed``.
        """
        h, w = image_size
        image_latents = rearrange(
            image_latents,
            "b (h w p1 p2) c -> b c (h p1) (w p2)",
            h=h // ps,
            w=w // ps,
            p1=ps,
            p2=ps,
        )
        return self.autoencoder.decode(image_latents)

    def _forward_token_group(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values,
        device: torch.device,
    ) -> Tuple[object, torch.Tensor]:
        """Extend the KV cache with one group of tokens using an all-ones mask.

        The mask has shape ``(batch, 1, group, group + cached)`` and is entirely
        ``True`` (presumably so tokens inside a group attend to each other
        bidirectionally -- depends on how the model interprets 4D masks).

        Returns:
            ``(updated_cache, hidden)`` where ``hidden`` holds the last hidden
            states of the tokens in this group.
        """
        group_len = inputs_embeds.shape[1]
        attn_mask = torch.ones(
            (inputs_embeds.shape[0], 1, group_len, group_len + _get_pkv_seq_len(past_key_values)),
            dtype=torch.bool,
            device=device,
        )
        outputs = self.text_encoder.model(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
            attention_mask=attn_mask,
        )
        return outputs.past_key_values, outputs.last_hidden_state[:, -group_len:]

    def _prefill(
        self,
        input_embeds: torch.Tensor,
        step_width: int,
        device: torch.device,
    ) -> Tuple[object, torch.Tensor]:
        """Run the prompt through the model, then its last ``step_width`` tokens as one group.

        Everything except the trailing ``step_width`` embeddings is processed
        with the model's default masking; the trailing group is processed via
        ``_forward_token_group``.
        """
        outputs = self.text_encoder.model(inputs_embeds=input_embeds[:, :-step_width, :], use_cache=True)
        return self._forward_token_group(input_embeds[:, -step_width:, :], outputs.past_key_values, device)

    @torch.no_grad()
    def _generate_single_prompt(
        self,
        prompt: str,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        num_images_per_prompt: int,
        generator: Optional[torch.Generator],
        show_progress_bar: bool,
    ) -> torch.Tensor:
        """Generate ``num_images_per_prompt`` images for one prompt.

        Tokens are produced ``parallel_num`` at a time: hidden states of the
        current group are sampled by the diffusion head, quantized with
        ``torch.sign``, projected back to embeddings and appended to the KV
        cache to obtain the hidden states for the next group.

        Args:
            prompt: Prompt text.
            height: Image height in pixels.
            width: Image width in pixels.
            num_inference_steps: Sampling steps passed to ``diffusion_head.sample``.
            guidance_scale: Passed as ``cfg`` to the diffusion head; an
                unconditional branch is run alongside when greater than ``1.0``.
            num_images_per_prompt: Number of images to generate.
            generator: Optional RNG forwarded to ``diffusion_head.sample``.
            show_progress_bar: Whether to wrap the decoding loop in ``tqdm``.

        Returns:
            The tensor returned by ``autoencoder.decode`` for the generated latents.

        Raises:
            ValueError: If ``[height, width]`` is not a supported size or the
                latent grid is not divisible into groups of ``parallel_num``.
        """
        image_size = (height, width)
        if list(image_size) not in self.supported_image_sizes:
            raise ValueError(
                f"image_size {list(image_size)} is not supported. "
                f"Please choose from {self.supported_image_sizes}"
            )

        h, w = height // self.vae_patch_size, width // self.vae_patch_size
        max_length = h * w
        step_width = self.parallel_num
        if max_length % step_width != 0:
            raise ValueError(
                f"max_length ({max_length}) must be divisible by parallel_num ({step_width})."
            )
        num_steps = max_length // step_width

        device = self._execution_device_fallback()
        dtype = next(self.text_encoder.parameters()).dtype

        input_embeds_cond, input_embeds_uncond, _ = self._encode_prompt_to_embeds(
            prompt=prompt,
            image_size=image_size,
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
        )
        pos_embed_for_diff = self._get_2d_embed(h, w, ps=self.ps).unsqueeze(0).to(device=device, dtype=dtype)

        autocast_ctx = (
            torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
            if device.type == "cuda"
            else nullcontext()
        )

        with autocast_ctx:
            pkv_c, hidden_c = self._prefill(input_embeds_cond, step_width, device)

            use_cfg = guidance_scale > 1.0 and input_embeds_uncond is not None
            pkv_u = None
            hidden_u = None
            if use_cfg:
                pkv_u, hidden_u = self._prefill(input_embeds_uncond, step_width, device)

            out_tokens = []
            step_iter = range(num_steps)
            if show_progress_bar:
                step_iter = tqdm(step_iter, total=num_steps, desc="Decoding steps")

            last_step = num_steps - 1
            for step in step_iter:
                # Conditional rows first, unconditional rows second.
                h_fused = torch.cat([hidden_c, hidden_u], dim=0) if use_cfg else hidden_c

                pos_slice = pos_embed_for_diff[:, step * step_width : (step + 1) * step_width, :]
                h_fused = h_fused + pos_slice
                pred_latents = self.diffusion_head.sample(
                    h_fused,
                    num_sampling_steps=num_inference_steps,
                    cfg=guidance_scale,
                    generator=generator,
                )
                curr_tokens = torch.sign(pred_latents)
                out_tokens.append(curr_tokens[:num_images_per_prompt])

                # The hidden states produced by feeding the final group back
                # into the model are never consumed, so skip that forward pass.
                # `continue` (not `break`) lets a tqdm iterator finish cleanly.
                if step == last_step:
                    continue

                model_input = self.projector(curr_tokens) + pos_slice
                pkv_c, hidden_c = self._forward_token_group(
                    model_input[:num_images_per_prompt], pkv_c, device
                )
                if use_cfg:
                    pkv_u, hidden_u = self._forward_token_group(
                        model_input[num_images_per_prompt:], pkv_u, device
                    )

            full_output = torch.cat(out_tokens, dim=1)
            return self._decode_tokens_to_image(full_output, image_size=(h, w), ps=self.ps)

    @torch.no_grad()
    def __call__(
        self,
        prompt: PromptType,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        show_progress_bar: bool = False,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate images for one or more prompts.

        Prompts are processed one after another; each yields
        ``num_images_per_prompt`` images and the results are concatenated in
        prompt order.

        Args:
            prompt: A prompt string or a list of prompt strings.
            height: Image height in pixels.
            width: Image width in pixels.
            num_inference_steps: Sampling steps for the diffusion head.
            guidance_scale: Classifier-free guidance scale.
            num_images_per_prompt: Images generated per prompt.
            generator: One generator shared by all prompts, or a list with one
                generator per prompt.
            output_type: ``"pil"`` (list of PIL images), ``"np"`` (float array
                in ``[0, 1]``, channels last) or ``"pt"`` (tensor in ``[0, 1]``).
            return_dict: Return an ``ImagePipelineOutput`` when ``True``, else a
                one-element tuple.
            show_progress_bar: Show a progress bar for each prompt's decoding loop.

        Raises:
            ValueError: On an empty prompt list, a generator list whose length
                differs from the number of prompts, or an unsupported
                ``output_type``.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if not prompts:
            raise ValueError("prompt must be a non-empty string or list of strings.")

        if isinstance(generator, list) and len(generator) != len(prompts):
            raise ValueError("When passing a list of generators, its length must equal len(prompt).")

        # Reject a bad output_type before spending time on generation.
        if output_type not in ("pil", "np", "pt"):
            raise ValueError(f"Unsupported output_type={output_type}. Expected 'pil', 'np', or 'pt'.")

        image_tensors = []
        for i, prompt_text in enumerate(prompts):
            prompt_generator = generator[i] if isinstance(generator, list) else generator
            image_tensors.append(
                self._generate_single_prompt(
                    prompt=prompt_text,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=num_images_per_prompt,
                    generator=prompt_generator,
                    show_progress_bar=show_progress_bar,
                )
            )

        images_pt = torch.cat(image_tensors, dim=0)

        if output_type == "pil":
            # Map [-1, 1] to [0, 255]; the uint8 cast truncates after clamping.
            images_uint8 = (
                torch.clamp(127.5 * images_pt + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )
            output_images = [Image.fromarray(image) for image in images_uint8]
        else:
            images_pt_01 = torch.clamp((images_pt + 1.0) / 2.0, 0.0, 1.0)
            if output_type == "pt":
                output_images = images_pt_01
            else:
                output_images = images_pt_01.permute(0, 2, 3, 1).float().cpu().numpy()

        if not return_dict:
            return (output_images,)
        return ImagePipelineOutput(images=output_images)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        height: int = 1024,
        width: int = 1024,
        num_sampling_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        seed: Optional[int] = None,
    ) -> List[Image.Image]:
        """Convenience wrapper around ``__call__`` that returns PIL images.

        Args:
            prompt: Prompt text.
            height: Image height in pixels.
            width: Image width in pixels.
            num_sampling_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Classifier-free guidance scale.
            num_images: Forwarded as ``num_images_per_prompt``.
            seed: When given, seeds a fresh ``torch.Generator``; otherwise no
                generator is passed.

        Returns:
            The generated PIL images. The progress bar is always shown.
        """
        generator = None
        if seed is not None:
            device = self._execution_device_fallback()
            # Use the actual execution device (including its index) for CUDA so
            # the generator matches the device the pipeline runs on; a bare
            # "cuda" would pick the current CUDA device instead.
            generator_device = device if device.type == "cuda" else "cpu"
            generator = torch.Generator(device=generator_device).manual_seed(seed)
        output = self(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_sampling_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            generator=generator,
            output_type="pil",
            return_dict=True,
            show_progress_bar=True,
        )
        return output.images
| |
|