BiliSakura committed on
Commit
2f5687a
·
verified ·
1 Parent(s): 1b8b5f7

Update all files for BitDance-14B-16x-diffusers

Browse files
bitdance_diffusers/pipeline_bitdance.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Module-level imports for the BitDance diffusers pipeline.
from __future__ import annotations

from contextlib import nullcontext
from typing import List, Optional, Sequence, Tuple, Union

import torch
from einops import rearrange
from PIL import Image
from tqdm.auto import tqdm

from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

from .constants import SUPPORTED_IMAGE_SIZES


# A single prompt string or a batch of prompt strings.
PromptType = Union[str, List[str]]
18
+
19
+
20
class BitDanceDiffusionPipeline(DiffusionPipeline):
    """Text-to-image pipeline that autoregressively decodes binarized ("bit")
    image tokens with a language model and a diffusion head, then decodes the
    tokens to pixels with a VAE-style autoencoder."""

    # CPU-offload order used by diffusers' model offloading machinery.
    model_cpu_offload_seq = "text_encoder->projector->diffusion_head->autoencoder"

    def __init__(
        self,
        tokenizer,
        text_encoder,
        autoencoder,
        diffusion_head,
        projector,
        supported_image_sizes: Optional[Sequence[Sequence[int]]] = None,
    ) -> None:
        """Register sub-modules and cache geometry derived from their configs.

        Args:
            tokenizer: Text tokenizer paired with ``text_encoder``.
            text_encoder: Language model providing embeddings and hidden states.
            autoencoder: Image autoencoder; ``patch_size`` sets latent granularity.
            diffusion_head: Head that samples latents; ``config.parallel_num``
                tokens are decoded per step.
            projector: Maps sampled bit tokens back to LM embedding space.
            supported_image_sizes: Optional override for the allowed
                (height, width) pairs; defaults to ``SUPPORTED_IMAGE_SIZES``.

        Raises:
            ValueError: If ``parallel_num`` is not a perfect square (tokens are
                laid out as square ps x ps patches).
        """
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            autoencoder=autoencoder,
            diffusion_head=diffusion_head,
            projector=projector,
        )

        # Persist the allowed sizes in the pipeline config so they survive save/load.
        image_sizes = supported_image_sizes or SUPPORTED_IMAGE_SIZES
        self.register_to_config(supported_image_sizes=[list(size) for size in image_sizes])

        # Geometry derived from the registered sub-modules.
        self.hidden_size = self.text_encoder.config.hidden_size
        self.vae_patch_size = self.autoencoder.patch_size
        self.parallel_num = int(self.diffusion_head.config.parallel_num)
        # ps is the side length of the square patch decoded per step.
        self.ps = int(self.parallel_num**0.5)
        if self.ps * self.ps != self.parallel_num:
            raise ValueError(
                f"parallel_num must be a perfect square (got {self.parallel_num})."
            )

        self._build_pos_embed()
54
+
55
+ @property
56
+ def supported_image_sizes(self) -> List[List[int]]:
57
+ return [list(size) for size in self.config.supported_image_sizes]
58
+
59
+ def _execution_device_fallback(self) -> torch.device:
60
+ if getattr(self, "_execution_device", None) is not None:
61
+ return self._execution_device
62
+ return next(self.text_encoder.parameters()).device
63
+
64
+ def _build_pos_embed(self) -> None:
65
+ max_resolution = max(max(size) for size in self.supported_image_sizes)
66
+ max_len = max_resolution // self.vae_patch_size
67
+ pos_embed_1d = self._get_1d_sincos_pos_embed(self.hidden_size // 2, max_len)
68
+ self.pos_embed_1d = pos_embed_1d
69
+
70
+ @staticmethod
71
+ def _get_1d_sincos_pos_embed(dim: int, max_len: int, pe_interpolation: float = 1.0) -> torch.Tensor:
72
+ if dim % 2 != 0:
73
+ raise ValueError(f"dim must be even, got {dim}")
74
+ omega = torch.arange(dim // 2, dtype=torch.float32)
75
+ omega /= dim / 2.0
76
+ omega = 1.0 / 10000**omega
77
+ pos = torch.arange(max_len, dtype=torch.float32) / pe_interpolation
78
+ out = torch.einsum("m,d->md", pos, omega)
79
+ emb_sin = torch.sin(out)
80
+ emb_cos = torch.cos(out)
81
+ return torch.cat([emb_sin, emb_cos], dim=1)
82
+
83
+ def _get_2d_embed(self, h: int, w: int, ps: int = 1) -> torch.Tensor:
84
+ emb_v = self.pos_embed_1d[:h]
85
+ emb_h = self.pos_embed_1d[:w]
86
+ grid_v = emb_v.view(h, 1, self.hidden_size // 2).repeat(1, w, 1)
87
+ grid_h = emb_h.view(1, w, self.hidden_size // 2).repeat(h, 1, 1)
88
+ pos_embed = torch.cat([grid_h, grid_v], dim=-1)
89
+ return rearrange(pos_embed, "(h p1) (w p2) c -> (h w p1 p2) c", p1=ps, p2=ps)
90
+
91
+ def _encode_prompt_to_embeds(
92
+ self,
93
+ prompt: str,
94
+ image_size: Tuple[int, int],
95
+ num_images_per_prompt: int,
96
+ guidance_scale: float,
97
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
98
+ device = self._execution_device_fallback()
99
+ model = self.text_encoder.model
100
+ tokenizer = self.tokenizer
101
+
102
+ cond_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
103
+ uncond_prompt = "<|im_start|>assistant\n"
104
+
105
+ cond_ids = torch.tensor(tokenizer.encode(cond_prompt), device=device, dtype=torch.long)
106
+ cond_emb = model.embed_tokens(cond_ids)
107
+ uncond_emb = None
108
+ if guidance_scale > 1.0:
109
+ uncond_ids = torch.tensor(tokenizer.encode(uncond_prompt), device=device, dtype=torch.long)
110
+ uncond_emb = model.embed_tokens(uncond_ids)
111
+
112
+ image_h, image_w = image_size
113
+ img_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>")
114
+ res_h_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_h // self.vae_patch_size}|>")
115
+ res_w_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_w // self.vae_patch_size}|>")
116
+ img_start_emb = model.embed_tokens(torch.tensor([img_start_id, res_h_token_id, res_w_token_id], device=device))
117
+
118
+ for i in range(1, self.parallel_num):
119
+ query_token_id = tokenizer.convert_tokens_to_ids(f"<|query_{i}|>")
120
+ query_token = torch.tensor([query_token_id], device=device, dtype=torch.long)
121
+ query_embed = model.embed_tokens(query_token)
122
+ img_start_emb = torch.cat([img_start_emb, query_embed], dim=0)
123
+
124
+ input_embeds_cond = torch.cat([cond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1)
125
+ input_embeds_uncond = None
126
+ if guidance_scale > 1.0 and uncond_emb is not None:
127
+ input_embeds_uncond = torch.cat([uncond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1)
128
+ return input_embeds_cond, input_embeds_uncond, img_start_emb
129
+
130
+ def _decode_tokens_to_image(self, image_latents: torch.Tensor, image_size: Tuple[int, int], ps: int = 1) -> torch.Tensor:
131
+ h, w = image_size
132
+ image_latents = rearrange(image_latents, "b (h w p1 p2) c -> b c (h p1) (w p2)", h=h // ps, w=w // ps, p1=ps, p2=ps)
133
+ return self.autoencoder.decode(image_latents)
134
+
135
    @torch.no_grad()
    def _generate_single_prompt(
        self,
        prompt: str,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        num_images_per_prompt: int,
        generator: Optional[torch.Generator],
        show_progress_bar: bool,
    ) -> torch.Tensor:
        """Autoregressively generate ``num_images_per_prompt`` images for one prompt.

        Decodes the latent grid ``parallel_num`` tokens per step with a KV-cached
        language model, optionally running a second (unconditional) stream for
        classifier-free guidance, then decodes the collected bit tokens through
        the autoencoder. Returns the decoded image tensor (batch, C, H, W);
        callers treat values as roughly in [-1, 1].

        Raises:
            ValueError: If (height, width) is not a supported size, or the token
                count is not divisible by ``parallel_num``.
        """
        image_size = (height, width)
        if list(image_size) not in self.supported_image_sizes:
            raise ValueError(
                f"image_size {list(image_size)} is not supported. "
                f"Please choose from {self.supported_image_sizes}"
            )

        # Latent-grid geometry: one token per VAE patch; `step_width` tokens
        # are decoded per autoregressive step.
        h, w = height // self.vae_patch_size, width // self.vae_patch_size
        max_length = h * w
        step_width = self.parallel_num
        if max_length % step_width != 0:
            raise ValueError(
                f"max_length ({max_length}) must be divisible by parallel_num ({step_width})."
            )
        num_steps = max_length // step_width

        device = self._execution_device_fallback()
        model = self.text_encoder.model
        dtype = next(self.text_encoder.parameters()).dtype

        input_embeds_cond, input_embeds_uncond, _ = self._encode_prompt_to_embeds(
            prompt=prompt,
            image_size=image_size,
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
        )
        # 2-D positional table for the diffusion head, one row per generated token.
        pos_embed_for_diff = self._get_2d_embed(h, w, ps=self.ps).unsqueeze(0).to(device=device, dtype=dtype)

        # bf16 autocast only on CUDA; CPU runs in the parameters' native dtype.
        autocast_ctx = (
            torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
            if device.type == "cuda"
            else nullcontext()
        )

        with autocast_ctx:
            # Prefill: run everything but the trailing `step_width` query tokens
            # causally, then run those last tokens with a fully-visible 4-D
            # boolean mask so they attend to each other bidirectionally.
            outputs_c = model(inputs_embeds=input_embeds_cond[:, :-step_width, :], use_cache=True)
            pkv_c = outputs_c.past_key_values

            bi_attn_mask = torch.ones(
                (input_embeds_cond.shape[0], 1, step_width, step_width + pkv_c[0][0].shape[2]),
                dtype=torch.bool,
                device=device,
            )
            outputs_c = model(
                inputs_embeds=input_embeds_cond[:, -step_width:, :],
                past_key_values=pkv_c,
                use_cache=True,
                attention_mask=bi_attn_mask,
            )
            pkv_c = outputs_c.past_key_values
            hidden_c = outputs_c.last_hidden_state[:, -step_width:]

            hidden_u = None
            pkv_u = None
            if guidance_scale > 1.0 and input_embeds_uncond is not None:
                # Unconditional stream for classifier-free guidance.
                # NOTE(review): bi_attn_mask's key length was sized from pkv_c;
                # the uncond prompt is shorter, so its cache length differs —
                # confirm the model tolerates the oversized mask here.
                outputs_u = model(inputs_embeds=input_embeds_uncond[:, :-step_width, :], use_cache=True)
                pkv_u = outputs_u.past_key_values
                outputs_u = model(
                    inputs_embeds=input_embeds_uncond[:, -step_width:, :],
                    past_key_values=pkv_u,
                    use_cache=True,
                    attention_mask=bi_attn_mask,
                )
                pkv_u = outputs_u.past_key_values
                hidden_u = outputs_u.last_hidden_state[:, -step_width:]

            out_tokens = []
            step_iter = range(num_steps)
            if show_progress_bar:
                step_iter = tqdm(step_iter, total=num_steps, desc="Decoding steps")

            for step in step_iter:
                # Stack cond + uncond hidden states so the diffusion head can
                # apply classifier-free guidance internally (via `cfg`).
                if guidance_scale > 1.0 and hidden_u is not None:
                    h_fused = torch.cat([hidden_c, hidden_u], dim=0)
                else:
                    h_fused = hidden_c

                pos_slice = pos_embed_for_diff[:, step * step_width : (step + 1) * step_width, :]
                h_fused = h_fused + pos_slice
                pred_latents = self.diffusion_head.sample(
                    h_fused,
                    num_sampling_steps=num_inference_steps,
                    cfg=guidance_scale,
                    generator=generator,
                )
                # Binarize predicted latents to bit tokens (sign per channel).
                curr_tokens = torch.sign(pred_latents)
                curr_embeds = self.projector(curr_tokens)
                # Keep only the conditional half of the batch as output tokens.
                out_tokens.append(curr_tokens[:num_images_per_prompt])

                # Feed the projected tokens (plus position) back into the LM to
                # extend both KV caches for the next step.
                model_input = curr_embeds + pos_slice
                bi_attn_mask = torch.ones(
                    (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + pkv_c[0][0].shape[2]),
                    dtype=torch.bool,
                    device=device,
                )
                outputs_c = model(
                    inputs_embeds=model_input[:num_images_per_prompt],
                    past_key_values=pkv_c,
                    use_cache=True,
                    attention_mask=bi_attn_mask[:num_images_per_prompt],
                )
                pkv_c = outputs_c.past_key_values
                hidden_c = outputs_c.last_hidden_state[:, -step_width:]

                if guidance_scale > 1.0 and hidden_u is not None and pkv_u is not None:
                    outputs_u = model(
                        inputs_embeds=model_input[num_images_per_prompt:],
                        past_key_values=pkv_u,
                        use_cache=True,
                        attention_mask=bi_attn_mask[num_images_per_prompt:],
                    )
                    pkv_u = outputs_u.past_key_values
                    hidden_u = outputs_u.last_hidden_state[:, -step_width:]

        # Concatenate per-step token groups along the sequence dimension and decode.
        full_output = torch.cat(out_tokens, dim=1)
        return self._decode_tokens_to_image(full_output, image_size=(h, w), ps=self.ps)
263
+
264
+ @torch.no_grad()
265
+ def __call__(
266
+ self,
267
+ prompt: PromptType,
268
+ height: int = 1024,
269
+ width: int = 1024,
270
+ num_inference_steps: int = 50,
271
+ guidance_scale: float = 7.5,
272
+ num_images_per_prompt: int = 1,
273
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
274
+ output_type: str = "pil",
275
+ return_dict: bool = True,
276
+ show_progress_bar: bool = False,
277
+ ) -> Union[ImagePipelineOutput, Tuple]:
278
+ prompts = [prompt] if isinstance(prompt, str) else list(prompt)
279
+ if len(prompts) == 0:
280
+ raise ValueError("prompt must be a non-empty string or list of strings.")
281
+
282
+ if isinstance(generator, list) and len(generator) != len(prompts):
283
+ raise ValueError("When passing a list of generators, its length must equal len(prompt).")
284
+
285
+ image_tensors = []
286
+ for i, prompt_text in enumerate(prompts):
287
+ prompt_generator = generator[i] if isinstance(generator, list) else generator
288
+ images = self._generate_single_prompt(
289
+ prompt=prompt_text,
290
+ height=height,
291
+ width=width,
292
+ num_inference_steps=num_inference_steps,
293
+ guidance_scale=guidance_scale,
294
+ num_images_per_prompt=num_images_per_prompt,
295
+ generator=prompt_generator,
296
+ show_progress_bar=show_progress_bar,
297
+ )
298
+ image_tensors.append(images)
299
+
300
+ images_pt = torch.cat(image_tensors, dim=0)
301
+ images_pt_01 = torch.clamp((images_pt + 1.0) / 2.0, 0.0, 1.0)
302
+
303
+ if output_type == "pt":
304
+ output_images = images_pt_01
305
+ elif output_type == "np":
306
+ output_images = images_pt_01.permute(0, 2, 3, 1).float().cpu().numpy()
307
+ elif output_type == "pil":
308
+ images_uint8 = (
309
+ torch.clamp(127.5 * images_pt + 128.0, 0, 255)
310
+ .permute(0, 2, 3, 1)
311
+ .to("cpu", dtype=torch.uint8)
312
+ .numpy()
313
+ )
314
+ output_images = [Image.fromarray(image) for image in images_uint8]
315
+ else:
316
+ raise ValueError(f"Unsupported output_type={output_type}. Expected 'pil', 'np', or 'pt'.")
317
+
318
+ if not return_dict:
319
+ return (output_images,)
320
+ return ImagePipelineOutput(images=output_images)
321
+
322
+ @torch.no_grad()
323
+ def generate(
324
+ self,
325
+ prompt: str,
326
+ height: int = 1024,
327
+ width: int = 1024,
328
+ num_sampling_steps: int = 50,
329
+ guidance_scale: float = 7.5,
330
+ num_images: int = 1,
331
+ seed: Optional[int] = None,
332
+ ) -> List[Image.Image]:
333
+ generator = None
334
+ if seed is not None:
335
+ device = self._execution_device_fallback()
336
+ generator_device = "cuda" if device.type == "cuda" else "cpu"
337
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
338
+ output = self(
339
+ prompt=prompt,
340
+ height=height,
341
+ width=width,
342
+ num_inference_steps=num_sampling_steps,
343
+ guidance_scale=guidance_scale,
344
+ num_images_per_prompt=num_images,
345
+ generator=generator,
346
+ output_type="pil",
347
+ return_dict=True,
348
+ show_progress_bar=True,
349
+ )
350
+ return output.images