# Source: StyleExper-V2 / src/pipeline.py
# Uploaded by oedevs (commit 56d35ce, message: "upload").
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers import FluxKontextPipeline
from diffusers.image_processor import (VaeImageProcessor)
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from torchvision.transforms.functional import pad
from .transformer_flux import FluxTransformer2DModel
# torch_xla is optional; when it is installed the denoising loop calls ``xm.mark_step()`` per step.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Resolutions Kontext was trained on, stored as ``(width, height)`` pairs. Reference images are
# snapped to the entry whose aspect ratio ``w / h`` is closest to their own.
PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392),
    (800, 1328), (832, 1248), (880, 1184), (944, 1104),
    (1024, 1024),
    (1104, 944), (1184, 880), (1248, 832), (1328, 800),
    (1392, 752), (1456, 720), (1504, 688), (1568, 672),
]
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Map an image token count linearly to the flow-matching shift ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``.
    Values outside that range are extrapolated, not clamped.

    Note: the default ``max_shift`` here is 1.16, while the scheduler-config fallback used by the
    pipeline's ``__call__`` is 1.15. The call sites in this file always pass every argument.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return intercept + slope * image_seq_len
def prepare_latent_image_ids_2(height, width, device, dtype):
    """Return an un-flattened ``(height // 2, width // 2, 3)`` positional-id grid.

    Channel 0 is zero, channel 1 is the row (y) index and channel 2 is the column (x) index of
    each 2x2 latent patch.
    """
    ys, xs = torch.meshgrid(
        torch.arange(height // 2, device=device),
        torch.arange(width // 2, device=device),
        indexing="ij",
    )
    return torch.stack([torch.zeros_like(ys), ys, xs], dim=-1).to(dtype)
def prepare_latent_subject_ids(height, width, device, dtype):
    """Return flattened ``(height // 2 * width // 2, 3)`` positional ids in row-major order.

    Row ``k`` is ``(0, k // cols, k % cols)`` where ``cols = width // 2``.
    """
    rows, cols = height // 2, width // 2
    # Row index repeats for every column; column index cycles once per row (row-major order).
    row_index = torch.arange(rows, device=device).repeat_interleave(cols)
    col_index = torch.arange(cols, device=device).repeat(rows)
    ids = torch.stack([torch.zeros_like(row_index), row_index, col_index], dim=1)
    return ids.to(device=device, dtype=dtype)
def _scaled_position_grid(rows, cols, row_scale, col_scale, device, dtype):
    """Return a ``(rows, cols, 3)`` grid whose entries are ``(0, i * row_scale, j * col_scale)``."""
    grid = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
    # Products are formed in float64 (same arithmetic as the Python float ``i * scale``) and only
    # then cast to ``dtype``. The cast happens on CPU so float64 is never needed on the target device.
    row_coords = (torch.arange(rows, dtype=torch.float64) * row_scale).to(dtype).to(device)
    col_coords = (torch.arange(cols, dtype=torch.float64) * col_scale).to(dtype).to(device)
    grid[..., 1] = row_coords[:, None]
    grid[..., 2] = col_coords[None, :]
    return grid


def resize_position_encoding(batch_size, original_height, original_width, target_height, target_width, device, dtype):
    """Build flattened positional ids for an original-size grid and an interpolated target grid.

    The target grid keeps the *original* coordinate range: entry ``(i, j)`` of the
    ``(target_height // 2, target_width // 2)`` grid gets coordinates
    ``(i * original_height / target_height, j * original_width / target_width)``.
    This lets a condition of a different size share the spatial frame of the original.

    Args:
        batch_size: Unused; kept for call-site compatibility.
        original_height, original_width: Size of the original latent grid (halved for 2x2 patches).
        target_height, target_width: Size of the resized latent grid (halved for 2x2 patches).
        device, dtype: Placement and dtype of the returned tensors.

    Returns:
        ``(latent_image_ids, cond_latent_image_ids)`` with shapes
        ``(original_height // 2 * original_width // 2, 3)`` and
        ``(target_height // 2 * target_width // 2, 3)``. Channel 0 is always zero.

    Raises:
        ZeroDivisionError: if ``target_height`` or ``target_width`` is 0.
    """
    orig_rows, orig_cols = original_height // 2, original_width // 2
    latent_image_ids = _scaled_position_grid(orig_rows, orig_cols, 1.0, 1.0, device, dtype)
    latent_image_ids = latent_image_ids.reshape(orig_rows * orig_cols, 3)

    # Spatial PE interpolation: spread the original coordinate range over the target grid.
    scale_h = original_height / target_height
    scale_w = original_width / target_width
    target_rows, target_cols = target_height // 2, target_width // 2
    cond_latent_image_ids = _scaled_position_grid(target_rows, target_cols, scale_h, scale_w, device, dtype)
    cond_latent_image_ids = cond_latent_image_ids.reshape(target_rows * target_cols, 3)
    return latent_image_ids, cond_latent_image_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    If the output has a ``latent_dist``, ``sample_mode="sample"`` draws from it with ``generator``
    and ``sample_mode="argmax"`` returns its mode. Otherwise, or for any other mode, a ``latents``
    attribute is returned if present.

    Raises:
        AttributeError: when neither path yields latents.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds_input_ids(
    self,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    ret_input_ids=False,
):
    """Encode ``prompt`` with the T5 encoder (``self.text_encoder_2``).

    This is a free function that takes the pipeline as ``self``.

    Args:
        self: Pipeline providing ``tokenizer_2``, ``text_encoder_2`` and ``_execution_device``.
        prompt: A prompt or list of prompts.
        num_images_per_prompt: Number of copies of each prompt embedding to emit.
        max_sequence_length: T5 padding/truncation length.
        device: Target device (defaults to ``self._execution_device``).
        dtype: Target dtype of the embeddings (defaults to ``self.text_encoder_2.dtype``).
        ret_input_ids: Also return the tokenized ids.

    Returns:
        ``prompt_embeds`` of shape ``(len(prompt) * num_images_per_prompt, max_sequence_length, dim)``.
        When ``ret_input_ids`` is true, ``text_input_ids`` is returned as well. Those ids have
        ``len(prompt)`` rows: they are NOT repeated per image.
    """
    device = device or self._execution_device
    # The embeddings come from the T5 encoder, so its dtype is the natural default.
    dtype = dtype or self.text_encoder_2.dtype
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
    text_inputs = self.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        # Slice at the T5 limit that was actually applied (not the CLIP tokenizer's limit).
        # The final untruncated token is the end-of-sequence marker, so it is excluded.
        removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )
    prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    if ret_input_ids:
        return prompt_embeds, text_input_ids
    return prompt_embeds
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt_input_ids(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
    lora_scale: Optional[float] = None,
    ret_input_ids=False,
):
    """Encode a prompt like ``FluxPipeline.encode_prompt``, optionally returning the T5 token ids.

    This is a free function that takes the pipeline as ``self``. Pooled embeddings come from CLIP
    (``prompt``); sequence embeddings come from T5 (``prompt_2``, defaulting to ``prompt``).

    Args:
        prompt, prompt_2: Prompts for the CLIP and T5 encoders respectively.
        device: Target device (defaults to ``self._execution_device``).
        num_images_per_prompt: Copies of each embedding to emit.
        prompt_embeds, pooled_prompt_embeds: Precomputed embeddings; when ``prompt_embeds`` is
            given, no encoding happens.
        max_sequence_length: T5 sequence length.
        lora_scale: Temporary LoRA scale applied to the text encoders while encoding.
        ret_input_ids: Also return the T5 token ids.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds, text_ids)``, where ``text_ids`` is a zero tensor of
        shape ``(seq_len, 3)``. When ``ret_input_ids`` is true, ``input_ids`` is appended. Those ids
        stay integer-typed and are moved to ``device``.

    Raises:
        ValueError: if ``ret_input_ids`` is requested together with precomputed ``prompt_embeds``,
            since no token ids exist in that case.
    """
    # Fail before touching LoRA scales: there are no token ids for precomputed embeddings.
    if ret_input_ids and prompt_embeds is not None:
        raise ValueError(
            "`ret_input_ids=True` requires tokenizing `prompt`, but `prompt_embeds` were passed; "
            "token ids are unavailable for precomputed embeddings."
        )
    device = device or self._execution_device
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
        if self.text_encoder is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder_2, lora_scale)
    prompt = [prompt] if isinstance(prompt, str) else prompt
    input_ids = None
    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
        # We only use the pooled prompt output from the CLIPTextModel
        pooled_prompt_embeds = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
        )
        prompt_embeds, input_ids = _get_t5_prompt_embeds_input_ids(
            self,
            prompt=prompt_2,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            ret_input_ids=True,
        )
    if self.text_encoder is not None:
        if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)
    if self.text_encoder_2 is not None:
        if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)
    dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    if ret_input_ids:
        # Move only: casting token ids to a half-precision float dtype would round large
        # vocabulary ids (fp16 is exact only up to 2048, bf16 up to 256) and break id matching.
        input_ids = input_ids.to(device=device)
        return prompt_embeds, pooled_prompt_embeds, text_ids, input_ids
    return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` via ``set_timesteps`` and return its timestep schedule.

    Exactly one schedule source is used: custom ``timesteps``, custom ``sigmas``, or
    ``num_inference_steps``. Extra ``kwargs`` (e.g. ``mu``) are forwarded to ``set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`): Scheduler to configure.
        num_inference_steps (`int`): Step count, used only when neither custom schedule is given.
        device (`str` or `torch.device`, *optional*): Where the scheduler places its timesteps.
        timesteps (`List[int]`, *optional*): Custom timesteps.
        sigmas (`List[float]`, *optional*): Custom sigmas.

    Returns:
        `Tuple[torch.Tensor, int]`: ``(scheduler.timesteps, number_of_steps)``. With a custom
        schedule, the step count is ``len(scheduler.timesteps)``; otherwise it is
        ``num_inference_steps`` unchanged.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given, or if the scheduler's
            ``set_timesteps`` lacks the needed keyword.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _require_keyword(keyword, label):
        # Inspect lazily: only custom schedules need the scheduler to accept the keyword.
        if keyword not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" {label} schedules. Please check whether you are using the correct scheduler."
            )

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        _require_keyword("timesteps", "timestep")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        _require_keyword("sigmas", "sigmas")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
def set_moe_layers_latents(
    subject_images,
    sty_encoder,
    siglip_processor,
    siglip_model,
    moe_layers=None,
):
    """Encode style reference image(s) and hand the result to every MoE layer.

    ``subject_images`` goes through ``siglip_processor`` and ``siglip_model`` (with hidden states
    requested) and then ``sty_encoder``. The output is flattened from dim 1 onward and passed to
    ``layer.set_latents(cond_hidden_states=...)`` for each layer.

    Args:
        subject_images: Image(s) accepted by ``siglip_processor``.
        sty_encoder: Callable mapping the SigLIP model output to a tensor.
        siglip_processor: Processor producing model inputs; its result must support ``.to(device)``.
        siglip_model: Image model exposing ``.device``.
        moe_layers: Iterable of layers with a ``set_latents`` method. ``None`` or empty means
            there is nothing to condition, and the encoders are not run.
    """
    layers = list(moe_layers) if moe_layers is not None else []
    if not layers:
        return
    with torch.no_grad():
        inputs = siglip_processor(images=subject_images, return_tensors="pt").to(siglip_model.device)
        siglip_feats = siglip_model(**inputs, output_hidden_states=True)
        cond_hidden_states = sty_encoder(siglip_feats).flatten(1)
    # NOTE(review): set_latents runs outside the no_grad block; the in-file caller
    # (MoEKontextPipeline.__call__) is itself decorated with @torch.no_grad().
    for layer in layers:
        layer.set_latents(cond_hidden_states=cond_hidden_states)
def insert_style_tokens(
    prompt_embeds,
    sty_token_id, con_token_id, sty_ori_token_id,
    sty_tokens,
    text_input_ids, text_ids
):
    """Splice style-token embeddings into each prompt embedding at the style placeholder.

    For each row ``i`` of ``prompt_embeds``, find the first position where ``text_input_ids[i]``
    equals ``sty_token_id`` and insert ``sty_tokens`` *before* it. The placeholder embedding is
    kept and shifted right. If the placeholder is absent, insertion happens before the last
    position (index -1).

    Args:
        prompt_embeds: ``(batch, seq, hidden)`` prompt embeddings.
        sty_token_id: Token id that marks the insertion point.
        con_token_id, sty_ori_token_id: Unused; kept for call-site compatibility.
        sty_tokens: Style embeddings shaped ``(hidden,)``, ``(num_tokens, hidden)`` or
            ``(1, num_tokens, hidden)``.
        text_input_ids: Token ids with at least ``batch`` rows, row ``i`` belonging to
            ``prompt_embeds[i]``.
        text_ids: ``(seq, 3)`` positional ids of the text tokens.

    Returns:
        ``(prompt_embeds, text_ids)``, each extended by ``num_tokens`` positions. The new text ids
        are zeros with the same dtype and device as ``text_ids``.
    """
    # Normalise the style tokens to (1, num_tokens, hidden) once, so the inserted length is
    # known for every supported input rank.
    style = sty_tokens
    if style.dim() == 1:
        style = style.unsqueeze(0)
    if style.dim() == 2:
        style = style.unsqueeze(0)
    num_style_tokens = style.shape[1]

    spliced = []
    for row in range(prompt_embeds.shape[0]):
        token_ids = text_input_ids[row].tolist()
        insert_at = token_ids.index(sty_token_id) if sty_token_id in token_ids else -1
        embed = prompt_embeds[row].unsqueeze(0)
        spliced.append(torch.cat([embed[:, :insert_at, :], style, embed[:, insert_at:, :]], dim=1))
    prompt_embeds = torch.cat(spliced, dim=0)

    # Inserted tokens get all-zero positional ids, like the other text tokens.
    padding_ids = torch.zeros(num_style_tokens, 3, device=text_ids.device, dtype=text_ids.dtype)
    text_ids = torch.cat([text_ids, padding_ids])
    return prompt_embeds, text_ids
from .moe import param_CondLoRAMoELayer
class myKontextPipeline(FluxKontextPipeline):
    """FLUX Kontext pipeline conditioned on spatial (structure) and subject (style) images.

    Unlike the single reference image of :class:`FluxKontextPipeline`, this variant accepts a
    list of spatial condition images and a list of subject/style images. Each group is
    VAE-encoded, packed into tokens, averaged over the images in the group, and appended after
    the noise tokens. Channel 0 of the positional ids is set to 1 for spatial-condition tokens
    and 2 for subject tokens. Noise-token ids come straight from ``_prepare_latent_image_ids``.
    """

    def _encode_reference_group(
        self,
        images,
        num_channels_latents,
        batch_size,
        generator,
        device,
        dtype,
        latent_scale=None,
    ):
        """VAE-encode a ``(B, N, C, H, W)`` image stack, pack it, and average over the N images.

        Args:
            images: ``(B, N, C, H, W)`` preprocessed images.
            num_channels_latents: VAE latent channels.
            batch_size: Generation batch size. A group batch that divides it evenly is repeated
                to match.
            generator, device, dtype: RNG for VAE sampling and the target placement/dtype.
            latent_scale: Optional multiplier applied to the latents before packing.

        Returns:
            ``(packed, latent_ids)``. ``packed`` is ``(batch, seq, C * 4)``; ``latent_ids`` are the
            ``seq`` positional ids of one image's token grid (a fresh tensor the caller may modify).
        """
        group_batch, num_images, channels, img_h, img_w = images.shape
        flat_images = images.reshape(group_batch * num_images, channels, img_h, img_w)
        flat_images = flat_images.to(device=device, dtype=dtype)
        image_latents = self._encode_vae_image(image=flat_images, generator=generator)
        if latent_scale is not None:
            image_latents = image_latents * latent_scale
        latent_h, latent_w = image_latents.shape[2:]
        latent_ids = self._prepare_latent_image_ids(
            image_latents.shape[0],
            latent_h // 2,
            latent_w // 2,
            device,
            dtype,
        )
        # Pack with the size the VAE actually produced. Any other size breaks ``_pack_latents``'s
        # view as soon as the reference is not exactly ``cond_size`` x ``cond_size``.
        packed = self._pack_latents(
            image_latents, group_batch * num_images, num_channels_latents, latent_h, latent_w
        )
        # Average the N images of the group into a single token sequence per batch entry.
        packed = packed.view(group_batch, num_images, *packed.shape[1:]).mean(dim=1)
        # Share one reference across several generations (e.g. num_images_per_prompt > 1).
        if 0 < group_batch < batch_size and batch_size % group_batch == 0:
            packed = packed.repeat_interleave(batch_size // group_batch, dim=0)
        return packed, latent_ids

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        subject_image,
        condition_image,
        latents=None,
        cond_number=1,
        sub_number=1,
    ):
        """Prepare packed noise latents plus the packed conditioning sequence and all ids.

        Args:
            batch_size: Number of generated samples (prompt batch * images per prompt).
            num_channels_latents: VAE latent channels (transformer ``in_channels // 4``).
            height, width: Output size in pixels.
            dtype, device: Dtype/device of every returned tensor.
            generator: RNG(s) used for the noise and for VAE latent sampling.
            subject_image: ``(B, N, C, H, W)`` subject/style images, or ``None``.
            condition_image: ``(B, N, C, H, W)`` spatial condition images, or ``None``.
            latents: Optional pre-packed noise latents ``(batch, seq, C * 4)``; sampled when ``None``.
            cond_number, sub_number: Kept for call compatibility. Each group is averaged into one
                token sequence, so it always contributes exactly one copy.

        Returns:
            ``(cond_latents, latent_image_ids, noise_latents)``. ``cond_latents`` concatenates the
            condition and subject tokens along dim 1 (it is empty when neither is given).
            ``latent_image_ids`` stacks the noise, condition and subject ids in that order.
        """
        # Latent grid of the generated image, kept even so it can be packed into 2x2 patches.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        if latents is None:
            shape = (batch_size, num_channels_latents, height, width)
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            noise_latents = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
        else:
            noise_latents = latents.to(device=device, dtype=dtype)
        noise_latent_image_ids = self._prepare_latent_image_ids(
            batch_size,
            height // 2,
            width // 2,
            device,
            dtype,
        )

        latents_to_concat = []  # conditioning tokens only; the noise tokens are returned separately
        latent_ids_to_concat = [noise_latent_image_ids]

        # Spatial condition images -> positional-id group 1.
        if condition_image is not None:
            cond_latents, cond_ids = self._encode_reference_group(
                condition_image, num_channels_latents, batch_size, generator, device, dtype
            )
            cond_ids[..., 0] = 1
            latents_to_concat.append(cond_latents)
            latent_ids_to_concat.append(cond_ids)

        # Subject/style images -> positional-id group 2.
        if subject_image is not None and getattr(self, "style_token_concat", True):
            inference_args = getattr(self, "inference_args", None)
            latent_scale = (inference_args.style_multi or 1) if inference_args else None
            subject_latents, subject_ids = self._encode_reference_group(
                subject_image,
                num_channels_latents,
                batch_size,
                generator,
                device,
                dtype,
                latent_scale=latent_scale,
            )
            if getattr(self, "style_offset", False):
                # Shift the subject tokens' row coordinate so they do not overlap the target grid.
                subject_ids[:, 1] += 64
            subject_ids[..., 0] = 2
            latents_to_concat.append(subject_latents)
            latent_ids_to_concat.append(subject_ids)

        if latents_to_concat:
            cond_latents = torch.concat(latents_to_concat, dim=1)
        else:
            # No reference images: an empty token sequence keeps the denoising loop's concat valid.
            cond_latents = noise_latents.new_zeros((noise_latents.shape[0], 0, noise_latents.shape[2]))
        latent_image_ids = torch.concat(latent_ids_to_concat, dim=0)
        return cond_latents, latent_image_ids, noise_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        true_cfg_scale: float = 1.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        max_area: int = 1024**2,
        _auto_resize: bool = True,
        spatial_images=None,
        subject_images=None,
        cond_size=1024,
    ):
        """Generate images from a prompt plus optional spatial and subject reference images.

        Arguments follow :class:`FluxKontextPipeline.__call__`, with these additions:
            spatial_images: List of spatial condition images; they are averaged in latent space.
            subject_images: List of subject/style reference images; they are averaged in latent space.
            cond_size: Stored as ``self.cond_size``; presumably read by other components (verify).
        """
        self.cond_size = cond_size
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # Rescale to ``max_area`` pixels at the requested aspect ratio, then snap both sides down
        # to a multiple of ``vae_scale_factor * 2`` so the latents pack into 2x2 patches.
        original_height, original_width = height, width
        aspect_ratio = width / height
        width = round((max_area * aspect_ratio) ** 0.5)
        height = round((max_area / aspect_ratio) ** 0.5)
        multiple_of = self.vae_scale_factor * 2
        width = width // multiple_of * multiple_of
        height = height // multiple_of * multiple_of
        if height != original_height or width != original_width:
            logger.warning(
                f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )
        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        cond_number = len(spatial_images) if spatial_images else 0
        sub_number = len(subject_images) if subject_images else 0

        def process_image(image):
            """Resize one reference image to a Kontext-friendly size and preprocess it to a tensor."""
            img = image[0] if isinstance(image, list) else image
            image_height, image_width = self.image_processor.get_default_height_width(img)
            aspect_ratio = image_width / image_height
            if _auto_resize:
                # Kontext is trained on specific resolutions, using one of them is recommended
                _, image_width, image_height = min(
                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
                )
            image_width = image_width // multiple_of * multiple_of
            image_height = image_height // multiple_of * multiple_of
            image = self.image_processor.resize(image, image_height, image_width)
            image = self.image_processor.preprocess(image, image_height, image_width)
            return image

        # Stack each group along dim 1 -> (B, N, C, H, W), as expected by ``prepare_latents``.
        if sub_number > 0:
            subject_image = torch.stack([process_image(img) for img in subject_images], dim=1)
        else:
            subject_image = None
        if cond_number > 0:
            condition_image = torch.stack([process_image(img) for img in spatial_images], dim=1)
        else:
            condition_image = None

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
        )
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

        # 3. Encode prompts
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if do_true_cfg:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                negative_text_ids,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        cond_latents, latent_ids, latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            subject_image,
            condition_image,
            latents,
            cond_number,
            sub_number,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # IP-adapter: when only one side (positive/negative) is given, use a black image for the other.
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
        ):
            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
        ):
            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}
        image_embeds = None
        negative_image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )
        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
                negative_ip_adapter_image,
                negative_ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # 6. Denoising loop
        # We set the index here to remove DtoH sync, helpful especially during compilation.
        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                self._current_timestep = t
                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                # Noise tokens first, then the (fixed) conditioning tokens.
                latent_model_input = torch.cat([latents, cond_latents], dim=1)
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                # Keep only the predictions for the noise tokens.
                noise_pred = noise_pred[:, : latents.size(1)]
                if do_true_cfg:
                    if negative_image_embeds is not None:
                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                    neg_noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=negative_text_ids,
                        img_ids=latent_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None
        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return FluxPipelineOutput(images=image)
class MoEKontextPipeline(myKontextPipeline):
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_optional_components = [
"image_encoder",
"feature_extractor",
]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    text_encoder_2: T5EncoderModel,
    tokenizer_2: T5TokenizerFast,
    transformer: FluxTransformer2DModel,
    image_encoder=None,
    feature_extractor=None,
    # Project-specific additions, attached as plain attributes below.
    extra_modules=None,
    extra_items=None,
):
    """Build the Kontext pipeline and attach the style-conditioning helpers.

    Args:
        scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer,
        image_encoder, feature_extractor: Forwarded unchanged to the parent constructor.
        extra_modules: Container exposing ``sty_encoder`` as an attribute and
            ``get_module("sty_token_encoder")``. Required despite the ``None`` default.
        extra_items: Container exposing ``siglip_processor``, ``siglip_model``,
            ``con_token_id``, ``sty_token_id``, ``sty_ori_token_id``, ``style_token_concat``
            and ``style_offset``. Required despite the ``None`` default.

    NOTE(review): these helpers are set as plain attributes rather than registered pipeline
    components, so presumably ``save_pretrained``/``.to()``/CPU offload do not manage them — verify.
    """
    super().__init__(
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )
    # Style encoders (one read as an attribute, the other through ``get_module``).
    self.sty_encoder = extra_modules.sty_encoder
    self.sty_token_encoder = extra_modules.get_module("sty_token_encoder")
    # SigLIP feature extractor and the special prompt-token ids, copied over in a fixed order.
    for attr_name in ("siglip_processor", "siglip_model", "con_token_id", "sty_token_id", "sty_ori_token_id"):
        setattr(self, attr_name, getattr(extra_items, attr_name))
    # Flags that control how subject tokens are added in ``prepare_latents``.
    self.style_token_concat = extra_items.style_token_concat or False
    self.style_offset = extra_items.style_offset
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
max_area: int = 1024**2,
_auto_resize: bool = True,
spatial_images=None,
subject_images=None,
cond_size=1024,
get_topk_indices=False,
):
self.cond_size = cond_size
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_height, original_width = height, width
aspect_ratio = width / height
width = round((max_area * aspect_ratio) ** 0.5)
height = round((max_area / aspect_ratio) ** 0.5)
multiple_of = self.vae_scale_factor * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
if height != original_height or width != original_width:
logger.warning(
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
cond_number = len(spatial_images) if spatial_images else 0
sub_number = len(subject_images) if subject_images else 0
def process_image(image):
    """Resize and preprocess one reference image for conditioning.

    The target size starts from the image's own default height/width as reported
    by ``self.image_processor.get_default_height_width``. When the enclosing
    call's ``_auto_resize`` flag is set, the size is replaced by the entry of
    ``PREFERRED_KONTEXT_RESOLUTIONS`` whose width/height ratio is closest to the
    image's. Both sides are then rounded down to a multiple of ``multiple_of``
    (``vae_scale_factor * 2`` in the enclosing call).

    Args:
        image: A single image or a list of images. For a list, only the first
            element is used to pick the target size.

    Returns:
        Whatever ``self.image_processor.preprocess`` produces for the resized
        input (presumably a batched tensor -- confirm against VaeImageProcessor).
    """
    probe = image if not isinstance(image, list) else image[0]
    target_h, target_w = self.image_processor.get_default_height_width(probe)

    if _auto_resize:
        # Kontext is trained on specific resolutions, using one of them is recommended.
        # Pick the preferred (w, h) whose aspect ratio is nearest to the input's.
        ratio = target_w / target_h
        target_w, target_h = min(
            PREFERRED_KONTEXT_RESOLUTIONS,
            key=lambda wh: abs(ratio - wh[0] / wh[1]),
        )

    # Snap both sides down so they stay divisible by the latent packing factor.
    target_w = (target_w // multiple_of) * multiple_of
    target_h = (target_h // multiple_of) * multiple_of

    resized = self.image_processor.resize(image, target_h, target_w)
    return self.image_processor.preprocess(resized, target_h, target_w)
if sub_number > 0:
subject_image_ls = []
for subject_image in subject_images:
subject_image_ls.append(process_image(subject_image))
subject_image = torch.stack(subject_image_ls, dim=1)
else:
subject_image = None
if cond_number > 0:
condition_image_ls = []
for img in spatial_images:
# condition_image = self.image_processor.preprocess(img, height=self.cond_size, width=self.cond_size)
# condition_image = condition_image.to(dtype=torch.float32)
condition_image_ls.append(process_image(img))
condition_image = torch.stack(condition_image_ls, dim=1)
else:
condition_image = None
moe_layers = [
module for name, module in self.transformer.named_modules()
if isinstance(module, param_CondLoRAMoELayer)
]
if sub_number > 0 and len(moe_layers) > 0: # 暂时先1个
set_moe_layers_latents(
subject_images[0],
self.sty_encoder,
self.siglip_processor,
self.siglip_model,
moe_layers,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
input_ids,
) = encode_prompt_input_ids(
self,
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
ret_input_ids=True
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
negative_text_ids,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if sub_number > 0 and self.sty_token_encoder: # 暂时先1个
inputs = self.siglip_processor(images=subject_images[0], return_tensors="pt").to(self.siglip_model.device)
with torch.no_grad():
style_feats = self.siglip_model(**inputs, output_hidden_states=True)
sty_tokens = self.sty_token_encoder(style_feats).to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
prompt_embeds, text_ids = insert_style_tokens(
prompt_embeds,
self.sty_token_id, self.con_token_id, self.sty_ori_token_id,
sty_tokens,
input_ids, text_ids
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
# latents, image_latents, latent_ids, image_ids = self.prepare_latents(
# image,
# batch_size * num_images_per_prompt,
# num_channels_latents,
# height,
# width,
# prompt_embeds.dtype,
# device,
# generator,
# latents,
# )
cond_latents, latent_ids, latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
subject_image,
condition_image,
latents,
cond_number,
sub_number
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
latent_model_input = torch.cat([latents, cond_latents], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if get_topk_indices:
topk_indices = []
for layer in moe_layers:
topk_indices.append(layer.top_k_idx)
return topk_indices
noise_pred = noise_pred[:, : latents.size(1)]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
for layer in moe_layers:
layer.clear_latents()
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)