# Source: build-tools/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py
# Uploaded by salmankhanpm ("Add files using upload-large-folder tool"), commit 4f4376a (verified).
# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import re
from typing import Any
import numpy as np
import PIL
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import LongCatImagePipelineOutput
# Optional torch_xla support: `xm.mark_step()` is called after every denoising step when XLA is present.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example passed to `replace_example_docstring` on `LongCatImageEditPipeline.__call__`, whose docstring
# contains an `Examples:` placeholder line. The text below is runtime data; keep it verbatim.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from PIL import Image
>>> import torch
>>> from diffusers import LongCatImageEditPipeline
>>> pipe = LongCatImageEditPipeline.from_pretrained(
... "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "change the cat to dog."
>>> input_image = Image.open("test.jpg").convert("RGB")
>>> image = pipe(
... input_image,
... prompt,
... num_inference_steps=50,
... guidance_scale=4.5,
... generator=torch.Generator("cpu").manual_seed(43),
... ).images[0]
>>> image.save("longcat_image_edit.png")
```
"""
# Adapted from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation
# NOTE(review): uses distinct, non-overlapping placeholders for word-internal apostrophes; apply the same fix to the
# source function so the two stay in sync.
def split_quotation(prompt, quote_pairs=None):
    """
    Split `prompt` into quoted and unquoted segments.

    A segment counts as quoted when it starts with the opening delimiter and ends with the matching closing
    delimiter of one of `quote_pairs`. Apostrophes inside words (e.g. "don't", "it's") are not treated as
    delimiters.

    Args:
        prompt (`str`):
            Text to split.
        quote_pairs (`list[tuple[str, str]]`, *optional*):
            `(open, close)` delimiter pairs. Defaults to ASCII single and double quotes plus the curly `‘’` and
            `“”` pairs.

    Returns:
        `list[tuple[str, bool]]`: The non-empty segments in their original order. Each segment is paired with
        `True` if it is a quoted span (delimiters included) and `False` otherwise.

    Examples::

        >>> split_quotation("Please write 'Hello' on the blackboard for me.")
        [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)]
    """
    # Hide word-internal apostrophes behind placeholders so the quote regex does not treat them as delimiters.
    # Each placeholder embeds its own index, so no placeholder is a substring of another and the restoration below
    # is order-independent. (Repeating one fixed unit made shorter placeholders match inside longer ones, which
    # corrupted prompts containing two or more distinct contractions.) dict.fromkeys de-duplicates while keeping a
    # deterministic first-occurrence order.
    word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
    mapping_word_internal_quote = []
    for i, word_src in enumerate(dict.fromkeys(word_internal_quote_pattern.findall(prompt))):
        word_tgt = f"longcat_${i}$_longcat"
        prompt = prompt.replace(word_src, word_tgt)
        mapping_word_internal_quote.append((word_src, word_tgt))

    if quote_pairs is None:
        quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")]
    # One alternative per pair: opening delimiter, shortest run without either delimiter, closing delimiter.
    pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs])
    # The capturing group makes re.split keep the quoted spans in the output list.
    parts = re.split(f"({pattern})", prompt)

    result = []
    for part in parts:
        for word_src, word_tgt in mapping_word_internal_quote:
            part = part.replace(word_tgt, word_src)
        if part:
            result.append((part, re.match(pattern, part) is not None))
    return result
# Adapted from diffusers.pipelines.longcat_image.pipeline_longcat_image.prepare_pos_ids
# NOTE(review): warnings go through the module logger and argument checks raise instead of relying on `assert`;
# apply the same change to the source function so the two stay in sync.
def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None):
    """
    Build 3-column position ids `(modality_id, row, col)` for a text sequence or an image token grid.

    Args:
        modality_id (`int`, defaults to 0):
            Value written into column 0 of every id.
        type (`str`, defaults to `"text"`):
            `"text"` for a 1-D token sequence or `"image"` for a 2-D token grid.
        start (`tuple[int, int]`, defaults to `(0, 0)`):
            Offsets added to columns 1 and 2 respectively.
        num_token (`int`, *optional*):
            Number of text tokens. Required for `"text"`, ignored (with a warning) for `"image"`.
        height (`int`, *optional*):
            Grid height. Required for `"image"`, ignored (with a warning) for `"text"`.
        width (`int`, *optional*):
            Grid width. Required for `"image"`, ignored (with a warning) for `"text"`.

    Returns:
        `torch.Tensor`: Shape `(num_token, 3)` for text, or `(height * width, 3)` for images in row-major order.
        Text ids advance columns 1 and 2 together. Image ids hold the row index in column 1 and the column index
        in column 2.

    Raises:
        ValueError: If the size arguments required by `type` are missing or zero.
        KeyError: If `type` is neither `"text"` nor `"image"`.
    """
    if type == "text":
        if not num_token:
            raise ValueError('`num_token` is required for "text" type.')
        if height or width:
            logger.warning('The parameters of height and width will be ignored in "text" type.')
        pos_ids = torch.zeros(num_token, 3)
        pos_ids[..., 0] = modality_id
        pos_ids[..., 1] = torch.arange(num_token) + start[0]
        pos_ids[..., 2] = torch.arange(num_token) + start[1]
    elif type == "image":
        if not (height and width):
            raise ValueError('`height` and `width` are required for "image" type.')
        if num_token:
            logger.warning('The parameter of num_token will be ignored in "image" type.')
        pos_ids = torch.zeros(height, width, 3)
        pos_ids[..., 0] = modality_id
        # Broadcast row indices down columns and column indices across rows.
        pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0]
        pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1]
        pos_ids = pos_ids.reshape(height * width, 3)
    else:
        # KeyError kept for backward compatibility with existing callers.
        raise KeyError(f'Unknown type {type}, only support "text" or "image".')
    return pos_ids
# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """
    Linearly map an image token count to the scheduler's `mu` shift.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`. Values outside that range
    are extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`list[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`list[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def calculate_dimensions(target_area, ratio):
    """
    Choose an output `(width, height)` with aspect ratio `ratio` (width / height) covering about `target_area` pixels.

    Both sides are rounded up to the next multiple of 16; exact multiples are kept as-is.

    Returns:
        `tuple[int, int]`: `(width, height)`.
    """
    raw_width = math.sqrt(target_area * ratio)
    raw_height = raw_width / ratio
    # ceil(x / 16) * 16 equals "keep if divisible, else round up to the next multiple of 16".
    return math.ceil(raw_width / 16) * 16, math.ceil(raw_height / 16) * 16
class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin):
    r"""
    The LongCat-Image-Edit pipeline for image editing.

    The transformer is conditioned on two inputs:

    - the last-layer hidden states of a Qwen2.5-VL model, computed from the editing instruction together with a
      half-resolution copy of the input image;
    - the VAE latents of the input image, concatenated with the noisy latents along the sequence dimension at every
      denoising step.

    Args:
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Scheduler used to denoise the latents.
        vae ([`AutoencoderKL`]):
            Variational auto-encoder used to encode the input image and decode the result.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Vision-language model whose last hidden states form the conditioning sequence.
        tokenizer ([`Qwen2Tokenizer`]):
            Tokenizer for `text_encoder`.
        text_processor ([`Qwen2VLProcessor`]):
            Processor whose `image_processor` prepares the input image for `text_encoder`.
        transformer ([`LongCatImageTransformer2DModel`]):
            Denoising transformer.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2Tokenizer,
        text_processor: Qwen2VLProcessor,
        transformer: LongCatImageTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            text_processor=text_processor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Latents are packed into 2x2 patches, so images must be divisible by vae_scale_factor * 2.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.image_processor_vl = text_processor.image_processor
        self.image_token = "<|image_pad|>"
        self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
        self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n"
        self.default_sample_size = 128
        self.tokenizer_max_length = 512

    def _encode_prompt(self, prompt, image):
        """
        Encode one editing instruction together with its reference image.

        Only `prompt[0]` is encoded (`check_inputs` limits `prompt` to a single entry). Quoted spans found by
        `split_quotation` are tokenized one character at a time; all other text is tokenized normally. The text
        tokens are truncated/padded to `self.tokenizer_max_length` and wrapped in the chat template, whose
        `<|image_pad|>` token is expanded to the number of vision tokens.

        Args:
            prompt (`list[str]`): Prompt list; only the first element is used.
            image: Image passed to the Qwen2-VL image processor.

        Returns:
            `torch.Tensor`: Last-layer hidden states of `text_encoder`, with batch dimension 1. The sequence starts
            at the `<|vision_start|>` token and stops before the template suffix, so it includes the vision tokens,
            the rest of the prompt template, and the (padded) instruction.
        """
        raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt")
        pixel_values = raw_vl_input["pixel_values"]
        image_grid_thw = raw_vl_input["image_grid_thw"]

        all_tokens = []
        for clean_prompt_sub, matched in split_quotation(prompt[0]):
            if matched:
                # Quoted text is tokenized character by character.
                for sub_word in clean_prompt_sub:
                    tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"]
                    all_tokens.extend(tokens)
            else:
                tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"]
                all_tokens.extend(tokens)

        if len(all_tokens) > self.tokenizer_max_length:
            logger.warning(
                "Your input was truncated because `max_sequence_length` is set to "
                f"{self.tokenizer_max_length} input token nums : {len(all_tokens)}"
            )
            all_tokens = all_tokens[: self.tokenizer_max_length]

        text_tokens_and_mask = self.tokenizer.pad(
            {"input_ids": [all_tokens]},
            max_length=self.tokenizer_max_length,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
        )

        # Expand the single <|image_pad|> into one token per merged vision patch.
        text = self.prompt_template_encode_prefix
        merge_length = self.image_processor_vl.merge_size**2
        while self.image_token in text:
            num_image_tokens = int(image_grid_thw.prod() // merge_length)
            text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
        text = text.replace("<|placeholder|>", self.image_token)

        prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"]
        suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"]

        # Output embeddings are kept from <|vision_start|> onward, so only the system part is dropped from the front.
        vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
        prefix_len = prefix_tokens.index(vision_start_token_id)
        suffix_len = len(suffix_tokens)

        prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
        suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)

        prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
        suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)

        input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
        attention_mask = torch.cat(
            (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1
        )

        input_ids = input_ids.unsqueeze(0).to(self.device)
        attention_mask = attention_mask.unsqueeze(0).to(self.device)

        pixel_values = pixel_values.to(self.device)
        image_grid_thw = image_grid_thw.to(self.device)

        text_output = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=True,
        )
        # Last hidden layer, detached from the graph; the template suffix is trimmed from the end.
        prompt_embeds = text_output.hidden_states[-1].detach()
        prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :]
        return prompt_embeds

    def encode_prompt(
        self,
        prompt: list[str] = None,
        image: torch.Tensor | None = None,
        num_images_per_prompt: int | None = 1,
        prompt_embeds: torch.Tensor | None = None,
    ):
        """
        Return prompt embeddings duplicated per image, plus their text position ids.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                Instruction to encode. Ignored when `prompt_embeds` is given.
            image (*optional*):
                Reference image for the vision-language encoder. Needed only when `prompt_embeds` is `None`.
            num_images_per_prompt (`int`, defaults to 1):
                Number of copies of each embedding to produce.
            prompt_embeds (`torch.Tensor`, *optional*):
                Precomputed embeddings of rank 3 `(batch, seq_len, dim)`; when given, encoding is skipped.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: Embeddings of shape `(batch * num_images_per_prompt, seq_len, dim)`,
            and text position ids of shape `(seq_len, 3)`.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt_embeds is None:
            prompt_embeds = self._encode_prompt(prompt, image)

        # Take the batch size from the embeddings themselves. `len(prompt)` fails when only `prompt_embeds` is
        # passed, and it disagrees with the encoder output, which always encodes `prompt[0]` only.
        batch_size, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to(
            self.device
        )
        return prompt_embeds, text_ids

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Fold `(B, C, H, W)` latents into `(B, (H/2)*(W/2), C*4)` tokens of 2x2 patches."""
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Inverse of `_pack_latents` for an output image of `height` x `width` pixels."""
        batch_size, num_patches, channels = latents.shape

        # The VAE downsamples by `vae_scale_factor`, and packing additionally requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        """Encode pixels to VAE latents (distribution mode), then apply the VAE shift and scaling factors."""
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return image_latents

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    def prepare_latents(
        self,
        image,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        prompt_embeds_length,
        device,
        generator,
        latents=None,
    ):
        """
        Prepare packed noisy latents, packed image latents, and their position ids.

        Image position ids use modality 2 and noise position ids use modality 1. Both grids start at
        `(prompt_embeds_length, prompt_embeds_length)` so they follow the text ids.

        Returns:
            `tuple`: `(latents, image_latents, latents_ids, image_latents_ids)`. The two image entries are `None`
            when `image` is `None`. User-supplied `latents` are only moved/cast and are expected to be packed already.
        """
        # The VAE downsamples by `vae_scale_factor`, and packing additionally requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        image_latents, image_latents_ids = None, None

        if image is not None:
            image = image.to(device=self.device, dtype=dtype)

            # Tensors that already have `latent_channels` channels are treated as latents and not re-encoded.
            if image.shape[1] != self.vae.config.latent_channels:
                image_latents = self._encode_vae_image(image=image, generator=generator)
            else:
                image_latents = image
            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
                additional_image_per_prompt = batch_size // image_latents.shape[0]
                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
                )
            else:
                image_latents = torch.cat([image_latents], dim=0)

            image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)

            # NOTE(review): image ids are cast to float64 while `latents_ids` keep prepare_pos_ids' default dtype;
            # torch.cat in __call__ promotes the combined ids. Kept as-is to avoid changing numerics.
            image_latents_ids = prepare_pos_ids(
                modality_id=2,
                type="image",
                start=(prompt_embeds_length, prompt_embeds_length),
                height=height // 2,
                width=width // 2,
            ).to(device, dtype=torch.float64)

        shape = (batch_size, num_channels_latents, height, width)
        latents_ids = prepare_pos_ids(
            modality_id=1,
            type="image",
            start=(prompt_embeds_length, prompt_embeds_length),
            height=height // 2,
            width=width // 2,
        ).to(device)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        else:
            latents = latents.to(device=device, dtype=dtype)

        return latents, image_latents, latents_ids, image_latents_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    def check_inputs(
        self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
    ):
        """
        Validate `__call__` arguments.

        Warns when the size is not divisible by `vae_scale_factor * 2`. Raises `ValueError` when `prompt` and
        `prompt_embeds` are both given or both missing, when `prompt` is not a `str` or a single-element `list`,
        or when `negative_prompt` and `negative_prompt_embeds` are both given.
        """
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None:
            if isinstance(prompt, str):
                pass
            elif isinstance(prompt, list) and len(prompt) == 1:
                pass
            else:
                raise ValueError(
                    f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})"
                )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    @torch.no_grad()
    def __call__(
        self,
        image: PIL.Image.Image | None = None,
        prompt: str | list[str] = None,
        negative_prompt: str | list[str] = None,
        num_inference_steps: int = 50,
        sigmas: list[float] | None = None,
        guidance_scale: float = 4.5,
        num_images_per_prompt: int | None = 1,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.FloatTensor | None = None,
        prompt_embeds: torch.FloatTensor | None = None,
        negative_prompt_embeds: torch.FloatTensor | None = None,
        output_type: str | None = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: dict[str, Any] | None = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image`):
                The image to edit (required). Its aspect ratio sets the working resolution, which targets about
                1024x1024 pixels with both sides rounded up to a multiple of 16.
            prompt (`str` or `list[str]` of length 1, *optional*):
                The editing instruction. Required unless `prompt_embeds` is passed.
            negative_prompt (`str` or `list[str]`, *optional*):
                Prompt for the unconditional branch of classifier-free guidance. Defaults to `""`.
            num_inference_steps (`int`, defaults to 50):
                Number of denoising steps.
            sigmas (`list[float]`, *optional*):
                Custom sigma schedule. Defaults to `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)`.
            guidance_scale (`float`, defaults to 4.5):
                Classifier-free guidance scale. Guidance is applied when the value is greater than 1.
            num_images_per_prompt (`int`, defaults to 1):
                Number of images generated per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                Generator(s) used to sample the initial noise.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, already packed (as produced by `_pack_latents`). Sampled if omitted.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Precomputed prompt embeddings. Cannot be combined with `prompt`.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Precomputed negative prompt embeddings. Cannot be combined with `negative_prompt`.
            output_type (`str`, defaults to `"pil"`):
                `"latent"` returns the packed latents; any other value is forwarded to the image post-processor.
            return_dict (`bool`, defaults to `True`):
                Whether to return a [`~pipelines.LongCatImagePipelineOutput`] instead of a tuple.
            joint_attention_kwargs (`dict`, *optional*):
                Stored on the pipeline and exposed through the `joint_attention_kwargs` property. This method does
                not forward it to the transformer.

        Examples:

        Returns:
            [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """
        if image is None:
            raise ValueError("`image` is required: LongCatImageEditPipeline edits an input image.")

        # The working resolution keeps the input aspect ratio at roughly 1024x1024 pixels.
        image_size = image[0].size if isinstance(image, list) else image.size
        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1])

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            calculated_height,
            calculated_width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Preprocess image: full working resolution for the VAE, half resolution for the VL text encoder.
        if image is not None and not (
            isinstance(image, torch.Tensor) and image.size(1) == self.vae.config.latent_channels
        ):
            image = self.image_processor.resize(image, calculated_height, calculated_width)
            prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2)
            image = self.image_processor.preprocess(image, calculated_height, calculated_width)

        negative_prompt = "" if negative_prompt is None else negative_prompt
        (prompt_embeds, text_ids) = self.encode_prompt(
            prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt
        )
        if self.do_classifier_free_guidance:
            (negative_prompt_embeds, negative_text_ids) = self.encode_prompt(
                prompt=negative_prompt,
                image=prompt_image,
                prompt_embeds=negative_prompt_embeds,
                num_images_per_prompt=num_images_per_prompt,
            )

        # 4. Prepare latent variables
        num_channels_latents = 16
        latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents(
            image,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            calculated_height,
            calculated_width,
            prompt_embeds.dtype,
            prompt_embeds.shape[1],
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # `guidance` is passed as None to the conditional transformer call.
        guidance = None

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        # Position ids for the concatenated [noisy latents, image latents] token sequence.
        if image is not None:
            latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0)
        else:
            latent_image_ids = latents_ids

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t

                latent_model_input = latents
                if image_latents is not None:
                    latent_model_input = torch.cat([latents, image_latents], dim=1)

                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
                with self.transformer.cache_context("cond"):
                    noise_pred_text = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        return_dict=False,
                    )[0]
                # Keep only predictions for the noisy-latent tokens (drop the appended image-latent tokens).
                noise_pred_text = noise_pred_text[:, :image_seq_len]
                if self.do_classifier_free_guidance:
                    with self.transformer.cache_context("uncond"):
                        noise_pred_uncond = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep / 1000,
                            encoder_hidden_states=negative_prompt_embeds,
                            txt_ids=negative_text_ids,
                            img_ids=latent_image_ids,
                            return_dict=False,
                        )[0]
                    noise_pred_uncond = noise_pred_uncond[:, :image_seq_len]
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_text

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # Advance the progress bar once per completed scheduler step.
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            if latents.dtype != self.vae.dtype:
                latents = latents.to(dtype=self.vae.dtype)

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return LongCatImagePipelineOutput(images=image)