# File size: 28,577 Bytes
# 97bc03d
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Pipeline
from transformers import AutoTokenizer, Gemma2Model
import copy
import torch.nn as nn
import torch.nn.functional as F
from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.pipelines.lumina2.pipeline_lumina2 import *
class Lumina2Decoder(torch.nn.Module):
    """Training wrapper around a Lumina2 diffusion transformer.

    The transformer's patch embedder (``x_embedder``) is replaced by a wider
    linear layer so that extra conditioning channels can be concatenated to
    the noisy VAE latents. The pretrained weights are copied into the leading
    input columns and the new columns start at zero.
    """

    def __init__(self, config, args):
        """Load all sub-models and widen the transformer input projection.

        Args:
            config: object exposing ``model_path``, the checkpoint location
                passed to every ``from_pretrained`` call.
            args: run options. Read here: ``ori_inp_dit``, ``grad_ckpt``,
                ``diffusion_resolution``, ``max_diff_seq_length`` and the
                optional ``revision`` / ``variant``. The last two are set to
                ``None`` on ``args`` itself when missing.
        """
        super().__init__()
        self.diffusion_model_path = config.model_path

        # Older arg namespaces may lack these; default them in place.
        if not hasattr(args, "revision"):
            args.revision = None
        if not hasattr(args, "variant"):
            args.variant = None

        self.tokenizer_one = AutoTokenizer.from_pretrained(
            self.diffusion_model_path,
            subfolder="tokenizer",
            revision=args.revision,
        )
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.diffusion_model_path, subfolder="scheduler"
        )
        # Independent copy used for timestep / sigma lookups during training.
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
        self.text_encoder_one = Gemma2Model.from_pretrained(
            self.diffusion_model_path,
            subfolder="text_encoder",
            revision=args.revision,
            variant=args.variant,
        )
        # Pipeline without VAE / transformer, used only for prompt encoding.
        self.text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
            self.diffusion_model_path,
            vae=None,
            transformer=None,
            text_encoder=self.text_encoder_one,
            tokenizer=self.tokenizer_one,
        )
        self.vae = AutoencoderKL.from_pretrained(
            self.diffusion_model_path,
            subfolder="vae",
            revision=args.revision,
            variant=args.variant,
        )

        if args.ori_inp_dit == "seq":
            from star.models.pixel_decoder.transformer_lumina2_seq import Lumina2Transformer2DModel
        elif args.ori_inp_dit == "ref":
            from star.models.pixel_decoder.transformer_lumina2 import Lumina2Transformer2DModel
        else:
            # A function-local import makes the name local to the whole
            # function. Without this branch every other mode (e.g. "none")
            # raised UnboundLocalError instead of using a module-level class.
            from diffusers import Lumina2Transformer2DModel
        self.transformer = Lumina2Transformer2DModel.from_pretrained(
            self.diffusion_model_path,
            subfolder="transformer",
            revision=args.revision,
            variant=args.variant,
        )

        # Widen the patch embedder by `vq_dim` extra channels per pixel.
        vq_dim = 512
        patch_size = self.transformer.config.patch_size
        in_channels = vq_dim + self.transformer.config.in_channels  # 48 for mask
        out_channels = self.transformer.x_embedder.out_features
        load_num_channel = self.transformer.config.in_channels * patch_size * patch_size
        self.transformer.register_to_config(in_channels=in_channels)
        transformer = self.transformer
        with torch.no_grad():
            new_proj = nn.Linear(
                in_channels * patch_size * patch_size, out_channels, bias=True
            )
            # Zero everything first so the added input columns contribute
            # nothing until they are trained.
            new_proj.weight.zero_()
            new_proj = new_proj.to(transformer.x_embedder.weight.dtype)
            new_proj.weight[:, :load_num_channel].copy_(transformer.x_embedder.weight)
            new_proj.bias.copy_(transformer.x_embedder.bias)
            transformer.x_embedder = new_proj

        self.ori_inp_dit = args.ori_inp_dit
        if args.ori_inp_dit == "seq":
            refiner_channels = transformer.noise_refiner[-1].dim
            with torch.no_grad():
                vae2cond_proj1 = nn.Linear(refiner_channels, refiner_channels, bias=True)
                vae2cond_act = nn.GELU(approximate='tanh')
                vae2cond_proj2 = nn.Linear(refiner_channels, refiner_channels, bias=False)
                # Zero-initialised last layer: the refiner outputs zeros at start.
                vae2cond_proj2.weight.zero_()
                ori_inp_refiner = nn.Sequential(
                    vae2cond_proj1,
                    vae2cond_act,
                    vae2cond_proj2,
                )
                transformer.ori_inp_refiner = ori_inp_refiner
            transformer.ori_inp_dit = self.ori_inp_dit
        elif args.ori_inp_dit == "ref":
            transformer.initialize_ref_weights()
            transformer.ori_inp_dit = self.ori_inp_dit

        transformer.requires_grad_(True)
        if args.grad_ckpt and args.diffusion_resolution == 1024:
            transformer.gradient_checkpointing = args.grad_ckpt
            transformer.enable_gradient_checkpointing()

        # The VAE is frozen and kept in float32.
        self.vae.requires_grad_(False)
        self.vae.to(dtype=torch.float32)
        self.args = args

        # Inference pipeline sharing the transformer, text encoder and VAE.
        self.pipe = Lumina2InstructPix2PixPipeline.from_pretrained(
            self.diffusion_model_path,
            transformer=transformer,
            text_encoder=self.text_encoder_one,
            vae=self.vae,
            torch_dtype=torch.bfloat16,
        )

        # Cache the negative-prompt outputs for the empty prompt; `forward`
        # substitutes them when the text condition is dropped.
        with torch.no_grad():
            _, _, self.uncond_prompt_embeds, self.uncond_prompt_attention_mask = (
                self.text_encoding_pipeline.encode_prompt(
                    "",
                    max_sequence_length=self.args.max_diff_seq_length,
                )
            )

    def compute_text_embeddings(self, prompt, text_encoding_pipeline):
        """Encode ``prompt`` without gradients.

        Returns:
            The first two values returned by ``encode_prompt``: the prompt
            embeddings and their attention mask.
        """
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
                prompt,
                max_sequence_length=self.args.max_diff_seq_length,
            )
        return prompt_embeds, prompt_attention_mask

    def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
        """Look up the scheduler sigma for each timestep.

        Args:
            timesteps: 1-D tensor of timesteps; each must occur exactly once
                in ``self.noise_scheduler_copy.timesteps``.
            n_dim: the result gets trailing singleton axes appended until it
                has at least this many dimensions.
            dtype: dtype the scheduler sigmas are cast to.

        Returns:
            Tensor with leading size ``len(timesteps)`` followed by singleton
            axes.
        """
        sigmas = self.noise_scheduler_copy.sigmas.to(dtype=dtype)
        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device=timesteps.device)
        # `.item()` raises unless there is exactly one matching schedule entry.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def forward(self, batch_gpu, batch, image_embeds):
        """Compute the flow-matching training loss for one batch.

        Args:
            batch_gpu: dict with ``"full_pixel_values"``, either 4-D
                ``(B, C, H, W)`` or 5-D ``(B, N, C, H, W)``. In the 5-D case
                the last image is the target; with ``N == 2`` the first image
                is the source of an edit pair.
            batch: dict with ``"prompts"``.
            image_embeds: conditioning embeddings, reshaped here to
                ``(B, 24, 24, C)``.

        Returns:
            Scalar loss tensor: weighted MSE between the transformer output
            and ``latents - noise``, averaged over the batch.
        """
        args = self.args
        resolution = self.args.diffusion_resolution
        pixel_values = batch_gpu["full_pixel_values"].to(dtype=self.vae.dtype)  # aux_image

        data_type = "t2i"
        if len(pixel_values.shape) == 5:
            num_img = pixel_values.shape[1]
            if num_img == 2:
                data_type = "edit"
                pixel_values_ori_img = pixel_values[:, 0]
            # The target is always the last image of the stack.
            pixel_values = pixel_values[:, -1]
        pixel_values = F.interpolate(
            pixel_values, size=(resolution, resolution), mode='bilinear', align_corners=False
        )
        if data_type == "edit" and self.ori_inp_dit != "none":
            pixel_values_ori_img = F.interpolate(
                pixel_values_ori_img, size=(resolution, resolution), mode='bilinear', align_corners=False
            )

        prompt = batch["prompts"]
        bs, _, _, _ = pixel_values.shape

        # NOTE(review): never populated in this method, so the "mix" and
        # "vit"/"vq*" modes below cannot work as written - confirm intent.
        image_prompt_embeds = None

        # (B, 24*24, C) -> (B, C, 24, 24), then resize to the latent grid.
        image_embeds_2d = image_embeds.reshape(bs, 24, 24, image_embeds.shape[-1]).permute(0, 3, 1, 2)
        image_embeds_2d = F.interpolate(
            image_embeds_2d,
            size=(args.diffusion_resolution // 8, args.diffusion_resolution // 8),
            mode='bilinear',
            align_corners=False,
        )

        control_emd = args.control_emd
        prompt_embeds, prompt_attention_mask = self.compute_text_embeddings(prompt, self.text_encoding_pipeline)
        if control_emd == "mix":
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)  # use mix
        elif control_emd == "null":
            prompt_embeds = torch.zeros_like(prompt_embeds)
            prompt_attention_mask = torch.ones_like(prompt_attention_mask)
        elif control_emd == "text":
            pass
        elif control_emd == "vit" or control_emd == "vq" or control_emd == "vqvae" or control_emd == "vqconcat" or control_emd == "vqconcatvit":
            prompt_embeds = image_prompt_embeds

        # Encode the target (and, for edits, the source) image into latents.
        latents = self.vae.encode(pixel_values).latent_dist.sample()
        latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        latents = latents.to(dtype=image_embeds.dtype)
        latents_ori_img = torch.zeros_like(latents)
        if data_type == "edit" and self.ori_inp_dit != "none":
            latents_ori_img = self.vae.encode(pixel_values_ori_img).latent_dist.sample()
            latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            latents_ori_img = latents_ori_img.to(dtype=image_embeds.dtype)

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # Sample a random timestep for each image
        # for weighting schemes where we sample timesteps non-uniformly
        u = compute_density_for_timestep_sampling(
            weighting_scheme=args.weighting_scheme,
            batch_size=bsz,
            logit_mean=args.logit_mean,
            logit_std=args.logit_std,
            mode_scale=args.mode_scale,
        )
        indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
        timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)

        # Add noise to the latents according to the noise magnitude at each
        # timestep (this is the forward diffusion process).
        sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype).to(device=noise.device)
        noisy_model_input = sigmas * noise + (1 - sigmas) * latents

        original_image_embeds = image_embeds_2d
        # Conditioning dropout to support classifier-free guidance during
        # inference. For more details check out section 3.2.1 of
        # https://arxiv.org/abs/2211.09800.
        if args.conditioning_dropout_prob is not None:
            # One draw per sample drives all three dropout decisions.
            random_p = torch.rand(bsz, device=latents.device)

            # Text condition: replaced by the empty-prompt embedding when
            # random_p < 2 * uncondition_prob.
            prompt_mask = random_p < 2 * args.uncondition_prob
            prompt_mask = prompt_mask.reshape(bsz, 1, 1)
            prompt_embeds = torch.where(
                prompt_mask,
                self.uncond_prompt_embeds.repeat(prompt_embeds.shape[0], 1, 1).to(prompt_embeds.device),
                prompt_embeds,
            )
            prompt_attention_mask = torch.where(
                prompt_mask[:, 0],
                self.uncond_prompt_attention_mask.repeat(prompt_embeds.shape[0], 1).to(prompt_embeds.device),
                prompt_attention_mask,
            )

            # Image-embedding condition: zeroed when
            # random_p <= conditioning_dropout_prob, and always for edit data.
            image_mask_dtype = original_image_embeds.dtype
            image_mask = 1 - (
                (random_p <= args.conditioning_dropout_prob).to(image_mask_dtype)
            )
            image_mask = image_mask.reshape(bsz, 1, 1, 1)
            if data_type == "edit":
                image_mask = 0
            original_image_embeds = image_mask * original_image_embeds

            # Source-image latents: zeroed when
            # uncondition_prob <= random_p < 3 * uncondition_prob.
            ori_latent_mask = 1 - (
                (random_p >= args.uncondition_prob).to(image_mask_dtype)
                * (random_p < 3 * args.uncondition_prob).to(image_mask_dtype)
            )
            ori_latent_mask = ori_latent_mask.reshape(bsz, 1, 1, 1)
            latents_ori_img = ori_latent_mask * latents_ori_img

        # Concatenate the image embeddings with the noisy latents (channels).
        concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)
        ref_image_hidden_states = None
        if self.ori_inp_dit == "dim":
            concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=1)
        elif self.ori_inp_dit == "seq":
            # Stack the source image below the noisy image along height.
            latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
            concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=2)
        elif self.ori_inp_dit == "ref":
            latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
            # NOTE(review): built but not passed to the transformer below,
            # whereas the inference pipeline does pass it - confirm intent.
            ref_image_hidden_states = latents_ori_img[:, None]

        # Rescale the timestep to [0, 1] and flip it before calling the model.
        timesteps = 1 - timesteps / self.noise_scheduler.config.num_train_timesteps
        model_pred = self.transformer(
            hidden_states=concatenated_noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_attention_mask,
            return_dict=False,
        )[0]
        if self.ori_inp_dit == "seq":
            # Keep only the rows belonging to the noisy image.
            model_pred = model_pred[:, :, :args.diffusion_resolution // 8, :]

        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
        target = latents - noise

        loss = torch.mean(
            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
            1,
        )
        # No `.item()` here: the value was unused and forced a host sync.
        return loss.mean()
class Lumina2InstructPix2PixPipeline(Lumina2Pipeline):
    """Lumina2 pipeline with image-embedding and source-image conditioning.

    With classifier-free guidance enabled, each step runs the transformer on
    a three-way batch ordered ``[unconditional, image-only, image + text]``
    and combines the predictions with separate text and image scales.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 4.0,
        negative_prompt: Union[str, List[str]] = None,
        sigmas: List[float] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        system_prompt: Optional[str] = None,
        cfg_trunc_ratio=[0.0, 1.0],
        cfg_normalization: bool = False,
        max_sequence_length: int = 256,
        control_emd="text",
        img_cfg_trunc_ratio=[0.0, 1.0],
        gen_image_embeds=None,
        only_t2i="vqconcat",
        image_prompt_embeds=None,
        ori_inp_img=None,
        img_guidance_scale=1.5,
        vq_guidance_scale=0,
        ori_inp_way="none",
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate images from a prompt plus optional image conditioning.

        Only the arguments handled differently from the parent pipeline are
        listed; the rest are forwarded to the inherited helpers.

        Args:
            num_images_per_prompt: ignored; overwritten by the leading size
                of ``gen_image_embeds`` (or ``image_prompt_embeds``).
            cfg_trunc_ratio: ``(low, high)``; text guidance is applied only
                while ``low < (step + 1) / num_steps < high``.
            img_cfg_trunc_ratio: same window, for image guidance.
            control_emd: ``"text"`` keeps the encoded prompt; ``"null"``
                zeroes the prompt embeddings and attention masks.
            gen_image_embeds: embeddings reshaped to ``(B, 24, 24, C)`` and
                concatenated to the latents channel-wise. Required when
                ``only_t2i == "vqconcat"``.
            only_t2i: ``True`` uses an all-zero 8-channel image condition;
                ``"vqconcat"`` uses ``gen_image_embeds``.
            image_prompt_embeds: used only for its leading size when
                ``gen_image_embeds`` is ``None``.
            ori_inp_img: source image for editing; a batch axis is added
                here (presumably ``(C, H, W)`` - confirm with callers).
            img_guidance_scale: scale of the image guidance term.
            vq_guidance_scale: scale of the extra VQ-conditioned term used
                when both a source image and non-zero embeddings are given.
            ori_inp_way: ``"none"``, ``"seq"`` (source latents stacked along
                height) or ``"ref"`` (source latents passed as
                ``ref_image_hidden_states``).

        Returns:
            ``ImagePipelineOutput`` or, if ``return_dict`` is false, a
            1-tuple holding the images (latents for ``output_type="latent"``).

        Raises:
            ValueError: if neither embedding tensor is given, if
                ``only_t2i == "vqconcat"`` without ``gen_image_embeds``, or
                if ``only_t2i`` has an unsupported value.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs

        # The image count comes from the conditioning tensors, not the arg.
        if gen_image_embeds is not None:
            num_images_per_prompt = gen_image_embeds.shape[0]
        elif image_prompt_embeds is not None:
            num_images_per_prompt = image_prompt_embeds.shape[0]
        else:
            raise ValueError(
                "Either `gen_image_embeds` or `image_prompt_embeds` must be provided."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device
        use_cfg = self.do_classifier_free_guidance

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            system_prompt=system_prompt,
        )
        if control_emd == "null":
            prompt_embeds = torch.zeros_like(prompt_embeds)
            prompt_attention_mask = torch.zeros_like(prompt_attention_mask)
            negative_prompt_embeds = prompt_embeds
            negative_prompt_attention_mask = prompt_attention_mask
        if use_cfg:
            # Batch order: [unconditional, image-only, image + text].
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, negative_prompt_embeds, prompt_embeds], dim=0
            )
            prompt_attention_mask = torch.cat(
                [negative_prompt_attention_mask, negative_prompt_attention_mask, prompt_attention_mask],
                dim=0,
            )

        # 4. Prepare latents.
        latent_channels = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Source-image latents; zeros unless a source image is given.
        latents_ori_img = torch.zeros_like(latents)
        if ori_inp_img is not None and ori_inp_way != "none":
            ori_inp_img = F.interpolate(
                ori_inp_img[None].to(latents.device, latents.dtype),
                size=(height, width),
                mode='bilinear',
                align_corners=False,
            )
            latents_ori_img = self.vae.encode(ori_inp_img).latent_dist.sample()
            latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            latents_ori_img = latents_ori_img.to(dtype=latents.dtype)
        if ori_inp_way != "none":
            negative_latents_ori_img = torch.zeros_like(latents_ori_img).to(prompt_embeds.dtype)
            if use_cfg:
                latents_ori_img = torch.cat(
                    [negative_latents_ori_img, latents_ori_img, latents_ori_img], dim=0
                )

        # Image-embedding condition concatenated to the latents.
        vq_in_edit = False
        if only_t2i == True:  # noqa: E712 - keep `==` so truthy 1 still matches
            image_latents = torch.zeros_like(latents)[:, :8]
        elif only_t2i == "vqconcat":
            if gen_image_embeds is None:
                raise ValueError(
                    '`gen_image_embeds` is required when `only_t2i` is "vqconcat".'
                )
            image_embeds_2d = gen_image_embeds.reshape(
                batch_size * num_images_per_prompt, 24, 24, gen_image_embeds.shape[-1]
            ).permute(0, 3, 1, 2)
            resized_embeds = F.interpolate(
                image_embeds_2d, size=(height // 8, width // 8), mode='bilinear', align_corners=False
            ).to(latents.device, latents.dtype)
            if ori_inp_img is not None and gen_image_embeds.mean() != 0:
                # Editing with non-zero embeddings: the main pass sees zeros
                # and the embeddings feed a separate guidance term.
                vq_in_edit = True
                image_vq_latents = resized_embeds
                image_latents = torch.zeros_like(image_vq_latents)
            else:
                image_latents = resized_embeds
        else:
            raise ValueError(
                f'Unsupported `only_t2i` value {only_t2i!r}; expected True or "vqconcat".'
            )
        negative_image_latents = torch.zeros_like(image_latents).to(prompt_embeds.dtype)
        if use_cfg:
            image_latents = torch.cat([negative_image_latents, image_latents, image_latents], dim=0)

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        # NOTE(review): axis 1 appears to be the channel axis of the
        # (B, C, H, W) latents, not a sequence length - confirm.
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        # Cast scheduler sigmas to the latent dtype (the original author
        # marked this as a bug workaround).
        self.scheduler.sigmas = self.scheduler.sigmas.to(latents.dtype)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Guidance is truncated (disabled) outside the open window.
                progress = (i + 1) / num_inference_steps
                do_classifier_free_truncation = not (
                    progress > cfg_trunc_ratio[0] and progress < cfg_trunc_ratio[1]
                )
                img_do_classifier_free_truncation = not (
                    progress > img_cfg_trunc_ratio[0] and progress < img_cfg_trunc_ratio[1]
                )

                # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
                current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
                latent_model_input = torch.cat([latents] * 3) if use_cfg else latents
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(latent_model_input.shape[0])
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

                # Only the "ref" transformer is given the extra keyword.
                extra_kwargs = {}
                if ori_inp_way == "seq":
                    latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
                    latent_model_input = torch.cat([latent_model_input, latents_ori_img_cat], dim=2)
                elif ori_inp_way == "ref":
                    latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
                    extra_kwargs["ref_image_hidden_states"] = latents_ori_img_cat[:, None]

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=current_timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    return_dict=False,
                    attention_kwargs=self.attention_kwargs,
                    **extra_kwargs,
                )[0]
                if ori_inp_way == "seq":
                    # Drop the rows belonging to the appended source image.
                    noise_pred = noise_pred[:, :, :height // 8, :]

                if vq_in_edit:
                    # Extra pass: text + embeddings, without the source image.
                    latent_model_vq_input = torch.cat([latents, image_vq_latents], dim=1)
                    if ori_inp_way == "seq":
                        latents_ori_img_cat_vq = torch.cat(
                            [torch.zeros_like(latents), image_vq_latents], dim=1
                        )
                        latent_model_vq_input = torch.cat(
                            [latent_model_vq_input, latents_ori_img_cat_vq], dim=2
                        )
                    noise_vq_pred = self.transformer(
                        hidden_states=latent_model_vq_input,
                        timestep=current_timestep[-1:],
                        encoder_hidden_states=prompt_embeds[-1:],
                        encoder_attention_mask=prompt_attention_mask[-1:],
                        return_dict=False,
                        attention_kwargs=self.attention_kwargs,
                    )[0]
                    if ori_inp_way == "seq":
                        noise_vq_pred = noise_vq_pred[:, :, :height // 8, :]

                # perform normalization-based guidance scale on a truncated timestep interval
                if use_cfg:
                    noise_pred_uncond, noise_pred_img, noise_pred_text = noise_pred.chunk(3)
                    if do_classifier_free_truncation and img_do_classifier_free_truncation:
                        noise_pred = noise_pred_text
                    else:
                        # A truncated term falls back to a scale of 1.
                        text_scale = 1 if do_classifier_free_truncation else guidance_scale
                        img_scale = 1 if img_do_classifier_free_truncation else img_guidance_scale
                        noise_pred = (
                            noise_pred_uncond
                            + text_scale * (noise_pred_text - noise_pred_img)
                            + img_scale * (noise_pred_img - noise_pred_uncond)
                        )
                    if vq_in_edit:
                        noise_pred = noise_pred + vq_guidance_scale * (noise_vq_pred - noise_pred_uncond)
                    # apply normalization after classifier-free guidance
                    if cfg_normalization:
                        cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
                        noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                        noise_pred = noise_pred * (cond_norm / noise_norm)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                # The prediction is negated before the scheduler step.
                noise_pred = -noise_pred
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # Undo the latent normalisation before decoding.
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
|