# star/models/pixel_decoder/lumina2_decoder.py
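"""Lumina2-based pixel decoder for STAR.

Wraps the Lumina2 DiT, VAE, and Gemma2 text encoder into a training module
(`Lumina2Decoder`) that computes a flow-matching loss conditioned on VQ image
embeddings, plus an InstructPix2Pix-style inference pipeline
(`Lumina2InstructPix2PixPipeline`) with separate text and image guidance.
"""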
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Pipeline
from transformers import AutoTokenizer, Gemma2Model
import copy
import torch.nn as nn
import torch.nn.functional as F
from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.pipelines.lumina2.pipeline_lumina2 import *  # also provides typing aliases, np, calculate_shift, retrieve_timesteps, ImagePipelineOutput, XLA_AVAILABLE/xm used below
class Lumina2Decoder(torch.nn.Module):
def __init__(self, config, args):
super().__init__()
self.diffusion_model_path = config.model_path
if not hasattr(args, "revision"):
args.revision = None
if not hasattr(args, "variant"):
args.variant = None
self.tokenizer_one = AutoTokenizer.from_pretrained(
self.diffusion_model_path,
subfolder="tokenizer",
revision=args.revision,
)
self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
self.diffusion_model_path, subfolder="scheduler"
)
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
self.text_encoder_one = Gemma2Model.from_pretrained(
self.diffusion_model_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
self.text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
self.diffusion_model_path,
vae=None,
transformer=None,
text_encoder=self.text_encoder_one,
tokenizer=self.tokenizer_one,
)
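        # Frozen VAE (kept in fp32) used to encode target images, and source images for edits, into latents.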
self.vae = AutoencoderKL.from_pretrained(
self.diffusion_model_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
        if args.ori_inp_dit == "seq":
            from star.models.pixel_decoder.transformer_lumina2_seq import Lumina2Transformer2DModel
        elif args.ori_inp_dit == "ref":
            from star.models.pixel_decoder.transformer_lumina2 import Lumina2Transformer2DModel
        else:
            # fall back to the stock diffusers transformer for the remaining modes (e.g. "dim" / "none")
            from diffusers import Lumina2Transformer2DModel
self.transformer = Lumina2Transformer2DModel.from_pretrained(
self.diffusion_model_path, subfolder="transformer", revision=args.revision, variant=args.variant
)
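        # Widen the patch-embedding projection so the transformer also takes the VQ feature map
        # (vq_dim channels) concatenated with the noisy latents; the extra input columns start at
        # zero so pretrained behaviour is unchanged at initialization.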
vq_dim = 512
patch_size = self.transformer.config.patch_size
        in_channels = vq_dim + self.transformer.config.in_channels  # VQ feature channels + original latent channels
out_channels = self.transformer.x_embedder.out_features
load_num_channel = self.transformer.config.in_channels * patch_size * patch_size
self.transformer.register_to_config(in_channels=in_channels)
transformer = self.transformer
with torch.no_grad():
new_proj = nn.Linear(
in_channels * patch_size * patch_size, out_channels, bias=True
)
new_proj.weight.zero_()
new_proj = new_proj.to(transformer.x_embedder.weight.dtype)
new_proj.weight[:, :load_num_channel].copy_(transformer.x_embedder.weight)
new_proj.bias.copy_(transformer.x_embedder.bias)
transformer.x_embedder = new_proj
self.ori_inp_dit = args.ori_inp_dit
if args.ori_inp_dit=="seq":
refiner_channels = transformer.noise_refiner[-1].dim
with torch.no_grad():
vae2cond_proj1 = nn.Linear(refiner_channels, refiner_channels, bias=True)
vae2cond_act = nn.GELU(approximate='tanh')
vae2cond_proj2 = nn.Linear(refiner_channels, refiner_channels, bias=False)
vae2cond_proj2.weight.zero_()
ori_inp_refiner = nn.Sequential(
vae2cond_proj1,
vae2cond_act,
vae2cond_proj2
)
transformer.ori_inp_refiner = ori_inp_refiner
transformer.ori_inp_dit = self.ori_inp_dit
elif args.ori_inp_dit=="ref":
transformer.initialize_ref_weights()
transformer.ori_inp_dit = self.ori_inp_dit
transformer.requires_grad_(True)
if args.grad_ckpt and args.diffusion_resolution==1024:
transformer.gradient_checkpointing = args.grad_ckpt
transformer.enable_gradient_checkpointing()
self.vae.requires_grad_(False)
self.vae.to(dtype=torch.float32)
self.args = args
self.pipe = Lumina2InstructPix2PixPipeline.from_pretrained(self.diffusion_model_path,
transformer=transformer,
text_encoder=self.text_encoder_one,
vae=self.vae,
torch_dtype=torch.bfloat16)
with torch.no_grad():
_, _, self.uncond_prompt_embeds, self.uncond_prompt_attention_mask = self.text_encoding_pipeline.encode_prompt(
"",
max_sequence_length=self.args.max_diff_seq_length,
)
    def compute_text_embeddings(self, prompt, text_encoding_pipeline):
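        """Encode `prompt` with the frozen Gemma2 text encoder; returns embeddings and attention mask."""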
with torch.no_grad():
prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
prompt,
max_sequence_length=self.args.max_diff_seq_length,
)
return prompt_embeds, prompt_attention_mask
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
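        """Look up the scheduler sigma for each sampled timestep and broadcast it to `n_dim` dimensions."""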
sigmas = self.noise_scheduler_copy.sigmas.to(dtype=dtype)
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device=timesteps.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
    def forward(self, batch_gpu, batch, image_embeds):
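        """One training step: encode the target image into VAE latents, add flow-matching noise,
        condition the DiT on text embeddings and the (optionally dropped-out) VQ feature map, and
        return the weighted MSE against the (latents - noise) target."""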
args = self.args
pixel_values = batch_gpu["full_pixel_values"].to(dtype=self.vae.dtype) #aux_image
data_type = "t2i"
if len(pixel_values.shape)==5:
bs,num_img,c,h,w = pixel_values.shape
if num_img==2:
data_type = "edit"
pixel_values_ori_img = pixel_values[:,0]
pixel_values = pixel_values[:,-1]
pixel_values = F.interpolate(pixel_values, size=(self.args.diffusion_resolution, self.args.diffusion_resolution), mode='bilinear',align_corners=False)
if data_type=="edit" and self.ori_inp_dit!="none":
pixel_values_ori_img = F.interpolate(pixel_values_ori_img, size=(self.args.diffusion_resolution, self.args.diffusion_resolution), mode='bilinear', align_corners=False)
prompt = batch["prompts"]
bs,_,_,_ = pixel_values.shape
image_prompt_embeds = None
image_embeds_2d = image_embeds.reshape(bs, 24, 24, image_embeds.shape[-1]).permute(0, 3, 1, 2)
image_embeds_2d = F.interpolate(image_embeds_2d, size=(args.diffusion_resolution//8, args.diffusion_resolution//8), mode='bilinear', align_corners=False)
control_emd = args.control_emd
prompt_embeds, prompt_attention_mask = self.compute_text_embeddings(prompt, self.text_encoding_pipeline)
        if control_emd == "mix":
            # note: this branch expects `image_prompt_embeds` to be supplied; it is left as None above
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
elif control_emd=="null":
prompt_embeds = torch.zeros_like(prompt_embeds)
prompt_attention_mask = torch.ones_like(prompt_attention_mask)
elif control_emd=="text":
pass
        elif control_emd in ("vit", "vq", "vqvae", "vqconcat", "vqconcatvit"):
            prompt_embeds = image_prompt_embeds
latents = self.vae.encode(pixel_values).latent_dist.sample()
latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
latents = latents.to(dtype=image_embeds.dtype)
latents_ori_img = torch.zeros_like(latents)
if data_type=="edit" and self.ori_inp_dit!="none":
latents_ori_img = self.vae.encode(pixel_values_ori_img).latent_dist.sample()
latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
latents_ori_img = latents_ori_img.to(dtype=image_embeds.dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype).to(device=noise.device)
#noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents
noisy_model_input = sigmas * noise + (1-sigmas) * latents
        # identity: noisy_model_input + sigmas * (latents - noise) = latents
# Get the additional image embedding for conditioning.
# Instead of getting a diagonal Gaussian here, we simply take the mode.
original_image_embeds = image_embeds_2d
        # Conditioning dropout to support classifier-free guidance at inference
        # (InstructPix2Pix, Sec. 3.2.1: https://arxiv.org/abs/2211.09800).
        if args.conditioning_dropout_prob is not None:
random_p = torch.rand(bsz, device=latents.device)
# Sample masks for the edit prompts.
prompt_mask = random_p < 2 * args.uncondition_prob
prompt_mask = prompt_mask.reshape(bsz, 1, 1)
# Final text conditioning.
#prompt_embeds = torch.where(prompt_mask, torch.zeros_like(prompt_embeds), prompt_embeds)
prompt_embeds = torch.where(prompt_mask, self.uncond_prompt_embeds.repeat(prompt_embeds.shape[0],1,1).to(prompt_embeds.device), prompt_embeds)
prompt_attention_mask = torch.where(prompt_mask[:,0], self.uncond_prompt_attention_mask.repeat(prompt_embeds.shape[0],1).to(prompt_embeds.device), prompt_attention_mask)
# Sample masks for the original images.
#random_p_vq = torch.rand(bsz, device=latents.device)
image_mask_dtype = original_image_embeds.dtype
image_mask = 1 - (
(random_p <= args.conditioning_dropout_prob).to(image_mask_dtype)
)
image_mask = image_mask.reshape(bsz, 1, 1, 1)
if data_type=="edit":
image_mask=0
# Final image conditioning.
original_image_embeds = image_mask * original_image_embeds
ori_latent_mask = 1 - (
(random_p >= args.uncondition_prob).to(image_mask_dtype)
* (random_p < 3 * args.uncondition_prob).to(image_mask_dtype)
)
ori_latent_mask = ori_latent_mask.reshape(bsz, 1, 1, 1)
latents_ori_img = ori_latent_mask * latents_ori_img
concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)
ref_image_hidden_states = None
if self.ori_inp_dit=="dim":
concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=1)
elif self.ori_inp_dit=="seq":
latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=2)
elif self.ori_inp_dit=="ref":
latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
ref_image_hidden_states = latents_ori_img[:,None]
# Predict the noise residual
        # Lumina2 uses t = 0 for pure noise and t = 1 for the clean image, so reverse and
        # normalize the scheduler timesteps before passing them to the transformer.
        timesteps = 1 - timesteps / self.noise_scheduler.config.num_train_timesteps
model_pred = self.transformer(
hidden_states=concatenated_noisy_latents,
timestep=timesteps,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
# ref_image_hidden_states = ref_image_hidden_states,
return_dict=False,
)[0]
if self.ori_inp_dit=="seq":
model_pred = model_pred[:, :, :args.diffusion_resolution//8, :]
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
        # Flow-matching target: the model is trained to predict (latents - noise).
        target = latents - noise
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
loss_value = loss.item()
return loss
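
# Minimal usage sketch (illustrative only; assumes `config.model_path` points to a Lumina2
# checkpoint and `args` carries the fields referenced above, e.g. ori_inp_dit, grad_ckpt,
# diffusion_resolution, control_emd, weighting_scheme, conditioning_dropout_prob):
#
#   decoder = Lumina2Decoder(config, args)
#   # batch_gpu["full_pixel_values"]: target (and optional source) images; batch["prompts"]: text prompts
#   loss = decoder(batch_gpu, batch, image_embeds)  # image_embeds: (B, 576, vq_dim) VQ features
#   loss.backward()
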
class Lumina2InstructPix2PixPipeline(Lumina2Pipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
num_inference_steps: int = 30,
guidance_scale: float = 4.0,
negative_prompt: Union[str, List[str]] = None,
sigmas: List[float] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
system_prompt: Optional[str] = None,
cfg_trunc_ratio=[0.0,1.0],
cfg_normalization: bool = False,
max_sequence_length: int = 256,
control_emd="text",
img_cfg_trunc_ratio =[0.0,1.0],
        gen_image_embeds=None,
        only_t2i="vqconcat",
        image_prompt_embeds=None,
        ori_inp_img=None,
        img_guidance_scale=1.5,
        vq_guidance_scale=0,
        ori_inp_way="none",
) -> Union[ImagePipelineOutput, Tuple]:
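        """InstructPix2Pix-style sampling for Lumina2.

        When CFG is enabled, three conditioning streams are run (unconditional, image-conditioned,
        text+image-conditioned) and combined with separate `guidance_scale` / `img_guidance_scale`,
        each optionally truncated to a timestep interval. VQ feature latents (`gen_image_embeds`)
        are concatenated to the noisy latents along the channel dimension; a source image
        (`ori_inp_img`) can additionally be injected via the "seq" or "ref" path selected by
        `ori_inp_way`.
        """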
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
num_images_per_prompt = gen_image_embeds.shape[0] if gen_image_embeds is not None else image_prompt_embeds.shape[0]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
system_prompt=system_prompt,
)
if gen_image_embeds is not None:
image_embeds_8=gen_image_embeds
if control_emd=="text":
pass
elif control_emd=="null":
prompt_embeds = torch.zeros_like(prompt_embeds)
prompt_attention_mask = torch.zeros_like(prompt_attention_mask)
negative_prompt_embeds = prompt_embeds
negative_prompt_attention_mask = prompt_attention_mask
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds,negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask,negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latents.
latent_channels = self.vae.config.latent_channels #self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents_ori_img = torch.zeros_like(latents)
if ori_inp_img is not None and ori_inp_way !="none":
ori_inp_img = F.interpolate(ori_inp_img[None].to(latents.device,latents.dtype), size=(height,width), mode='bilinear',align_corners=False)
latents_ori_img = self.vae.encode(ori_inp_img).latent_dist.sample()
latents_ori_img = (latents_ori_img- self.vae.config.shift_factor) * self.vae.config.scaling_factor
latents_ori_img = latents_ori_img.to(dtype=latents.dtype)
if ori_inp_way !="none":
negative_latents_ori_img = torch.zeros_like(latents_ori_img).to(prompt_embeds.dtype)
latents_ori_img = torch.cat([negative_latents_ori_img,latents_ori_img, latents_ori_img], dim=0) if self.do_classifier_free_guidance else latents_ori_img
vq_in_edit = False
if only_t2i==True:
image_latents = torch.zeros_like(latents)[:,:8]
elif only_t2i=="vqconcat":
image_embeds_2d = image_embeds_8.reshape(batch_size* num_images_per_prompt,24,24,image_embeds_8.shape[-1]).permute(0,3,1,2)
if ori_inp_img is not None and image_embeds_8.mean()!=0:
vq_in_edit = True
image_vq_latents = F.interpolate(image_embeds_2d, size=(height//8,width//8), mode='bilinear',align_corners=False).to(latents.device,latents.dtype)
image_latents = torch.zeros_like(image_vq_latents)
else:
image_latents = F.interpolate(image_embeds_2d, size=(height//8,width//8), mode='bilinear',align_corners=False).to(latents.device,latents.dtype)
negative_image_latents = torch.zeros_like(image_latents).to(prompt_embeds.dtype)
image_latents = torch.cat([negative_image_latents,image_latents, image_latents], dim=0) if self.do_classifier_free_guidance else image_latents
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
        self.scheduler.sigmas = self.scheduler.sigmas.to(latents.dtype)  # avoid dtype mismatch in scheduler.step
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# compute whether apply classifier-free truncation on this timestep
do_classifier_free_truncation = not ((i + 1) / num_inference_steps > cfg_trunc_ratio[0] and (i + 1) / num_inference_steps < cfg_trunc_ratio[1])
img_do_classifier_free_truncation = not ((i + 1) / num_inference_steps > img_cfg_trunc_ratio[0] and (i + 1) / num_inference_steps < img_cfg_trunc_ratio[1])
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
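                # Three conditioning streams for CFG: [unconditional, image-only, text+image].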
latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
ref_image_hidden_states = None
if ori_inp_way=="seq":
latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
latent_model_input = torch.cat([latent_model_input, latents_ori_img_cat], dim=2)
elif ori_inp_way=="ref":
latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
ref_image_hidden_states = latents_ori_img_cat[:,None]
if ori_inp_way=="ref":
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=current_timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
return_dict=False,ref_image_hidden_states=ref_image_hidden_states,
attention_kwargs=self.attention_kwargs,
)[0]
else:
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=current_timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]
if ori_inp_way=="seq":
noise_pred = noise_pred[:,:,:height//8,:]
if vq_in_edit:
latent_model_vq_input = torch.cat([latents, image_vq_latents], dim=1)
if ori_inp_way=="seq":
latents_ori_img_cat_vq = torch.cat([torch.zeros_like(latents), image_vq_latents], dim=1)
latent_model_vq_input = torch.cat([latent_model_vq_input, latents_ori_img_cat_vq], dim=2)
noise_vq_pred = self.transformer(
hidden_states=latent_model_vq_input,
timestep=current_timestep[-1:],
encoder_hidden_states=prompt_embeds[-1:],
encoder_attention_mask=prompt_attention_mask[-1:],
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]
if ori_inp_way=="seq":
noise_vq_pred = noise_vq_pred[:,:,:height//8,:]
                # combine the three prediction streams with separate (optionally truncated) text and image guidance scales
if self.do_classifier_free_guidance:
noise_pred_uncond,noise_pred_img, noise_pred_text = noise_pred.chunk(3)
if not do_classifier_free_truncation and not img_do_classifier_free_truncation:
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_img)+ img_guidance_scale * (noise_pred_img - noise_pred_uncond)
elif not do_classifier_free_truncation and img_do_classifier_free_truncation:
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_img)+ 1 * (noise_pred_img - noise_pred_uncond)
elif do_classifier_free_truncation and not img_do_classifier_free_truncation:
noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_img)+ img_guidance_scale * (noise_pred_img - noise_pred_uncond)
else:
noise_pred = noise_pred_text
if vq_in_edit:
noise_pred = noise_pred +vq_guidance_scale*(noise_vq_pred-noise_pred_uncond)
# apply normalization after classifier-free guidance
if cfg_normalization:
cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_pred = noise_pred * (cond_norm / noise_norm)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
                noise_pred = -noise_pred  # model predicts (latents - noise); the scheduler expects the opposite sign
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
else:
image = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)