Spaces:
Runtime error
Runtime error
| import inspect | |
| from typing import Any, Callable, Dict, List, Optional, Union | |
| import numpy as np | |
| import PIL.Image | |
| from PIL import Image | |
| import torch | |
| from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
| from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin | |
| from diffusers.models.autoencoders import AutoencoderKL | |
| from diffusers.models.transformers import FluxTransformer2DModel | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.utils import ( | |
| USE_PEFT_BACKEND, | |
| is_torch_xla_available, | |
| logging, | |
| replace_example_docstring, | |
| scale_lora_layers, | |
| unscale_lora_layers, | |
| ) | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
| from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput | |
| from diffusers.pipelines import FluxInpaintPipeline | |
| from diffusers.pipelines.flux.pipeline_flux_inpaint import calculate_shift, retrieve_latents, retrieve_timesteps | |
class FluxTryonPipeline(FluxInpaintPipeline):
    """Flux inpainting pipeline for side-by-side (try-on style) generation.

    The input image is treated as a horizontal concatenation of panels: the
    leftmost ``target_width`` pixels form the output panel and the remaining
    columns are conditioning. The two parts are VAE-encoded separately, panels
    get distinct indices in channel 0 of the positional ids, and only the
    output panel is decoded at the end.
    """

    # Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype, target_width=-1, tryon=False):
        """Return the ``(height * width, 3)`` positional ids for a packed latent grid.

        Channel 0 holds a panel index, channel 1 the row index and channel 2 the
        column index. Must be a staticmethod: it is called as
        ``self._prepare_latent_image_ids(...)`` without a ``self`` parameter.

        Args:
            batch_size: Unused; kept so the signature matches the diffusers helper.
            height: Packed latent height (rows of 2x2 patches).
            width: Packed latent width (columns of 2x2 patches).
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.
            target_width: Packed width of the left (output) panel. ``-1`` yields
                plain Flux ids (panel 0 everywhere, columns ``0..width-1``).
                Otherwise every column from ``target_width`` on is panel 1.
            tryon: Only used when ``target_width != -1``. If True, columns from
                ``2 * target_width`` on become panel 2 and column indices restart
                at 0 at ``target_width``. If False, columns run ``0..width-1``.

        Returns:
            Tensor of shape ``(height * width, 3)`` on ``device`` with ``dtype``.
        """
        latent_image_ids = torch.zeros(height, width, 3)
        # Channel 1: row index, identical in every mode.
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        if target_width == -1:
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
        else:
            # Every column right of the output panel is tagged as panel 1.
            latent_image_ids[:, target_width:, 0] = 1
            if tryon:
                # Columns from 2 * target_width on are tagged as panel 2.
                latent_image_ids[:, target_width * 2:, 0] = 2
                # Column indices restart at 0 at target_width. The first conditioning
                # panel therefore reuses the output panel's x positions (when the
                # widths match).
                latent_image_ids[:, :target_width, 2] = (
                    latent_image_ids[:, :target_width, 2] + torch.arange(target_width)[None, :]
                )
                latent_image_ids[:, target_width:, 2] = (
                    latent_image_ids[:, target_width:, 2] + torch.arange(width - target_width)[None, :]
                )
            else:
                latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        return latent_image_ids.to(device=device, dtype=dtype)

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_channels_latents,
        height,
        width,
        target_width,
        tryon,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Encode the panel image and build the initial packed latents.

        Args:
            image: Preprocessed image tensor. It is sliced as
                ``image[:, :, :, :target_width]``, so the last axis is pixel width.
            timestep: Passed to ``scheduler.scale_noise`` to noise the image latents.
            batch_size: Effective batch size (prompts x images per prompt).
            num_channels_latents: Latent channels before 2x2 packing.
            height: Pixel height of ``image``.
            width: Pixel width of ``image``.
            target_width: Pixel width of the output panel. ``None`` means the
                whole width.
            tryon: Forwarded to ``_prepare_latent_image_ids``.
            dtype: Dtype of the latents.
            device: Device of the latents.
            generator: RNG or list of RNGs. A list must have ``batch_size`` entries.
            latents: Optional pre-made latents used instead of noising the image
                latents. They are passed through ``_pack_latents``, so presumably
                they are unpacked ``(B, C, H, W)`` — TODO confirm.

        Returns:
            ``(latents, noise, image_latents, latent_image_ids)``. The first three
            are packed.

        Raises:
            ValueError: if a generator list does not match ``batch_size``, or the
                image batch cannot be evenly repeated to ``batch_size``.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if target_width is None:
            # No split requested: the whole image is the output panel.
            target_width = width

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        shape = (batch_size, num_channels_latents, height, width)
        # Width of the output panel measured in packed (2x2 patch) columns.
        sp = 2 * (int(target_width) // (self.vae_scale_factor * 2)) // 2
        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype, sp, tryon)

        image = image.to(device=device, dtype=dtype)
        # Encode the output panel and the conditioning part separately, then stitch
        # them back together along the width axis. Empty parts (e.g. when
        # target_width covers the whole image) are skipped; the VAE cannot encode
        # a zero-width tensor.
        # NOTE(review): target_width should be a multiple of 2 * vae_scale_factor
        # so that the pixel split and the packed split `sp` line up.
        img_parts = [image[:, :, :, :target_width], image[:, :, :, target_width:]]
        img_parts = [img for img in img_parts if img.shape[-1] > 0]
        image_latents = [self._encode_vae_image(image=img, generator=generator) for img in img_parts]
        image_latents = torch.cat(image_latents, dim=-1)

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
        else:
            noise = latents.to(device)
            latents = noise

        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        return latents, noise, image_latents, latent_image_ids

    def prepare_mask_latents(
        self,
        mask,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        """Resize the mask to the latent grid and encode the masked image, both packed.

        Args:
            mask: Mask tensor, resized with nearest-neighbour interpolation to the
                latent ``(height, width)``.
            masked_image: Pixel image to VAE-encode. If it already has 16
                channels, it is treated as latents and used directly.
            batch_size: Number of prompts. It is multiplied by
                ``num_images_per_prompt`` here.
            num_channels_latents: Latent channels before 2x2 packing.
            num_images_per_prompt: Images generated per prompt.
            height: Pixel height.
            width: Pixel width.
            dtype: Dtype of the outputs.
            device: Device of the outputs.
            generator: RNG(s) used when sampling from the VAE posterior.

        Returns:
            ``(mask, masked_image_latents)``, both packed. The mask is repeated over
            ``num_channels_latents`` channels before packing, so it has the same
            shape as the packed latents.

        Raises:
            ValueError: if the mask or image batch cannot be evenly repeated to the
                total batch size.
        """
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(mask, size=(height, width), mode="nearest")
        mask = mask.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        masked_image = masked_image.to(device=device, dtype=dtype)

        if masked_image.shape[1] == 16:
            masked_image_latents = masked_image
        else:
            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        masked_image_latents = self._pack_latents(
            masked_image_latents,
            batch_size,
            num_channels_latents,
            height,
            width,
        )
        mask = self._pack_latents(
            mask.repeat(1, num_channels_latents, 1, 1),
            batch_size,
            num_channels_latents,
            height,
            width,
        )

        return mask, masked_image_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_width: Optional[int] = None,
        tryon: bool = False,
        padding_mask_crop: Optional[int] = None,
        strength: float = 0.6,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Run try-on inpainting on a horizontally concatenated panel image.

        Runs under ``torch.no_grad()`` so no autograd graph is built across the
        denoising steps. Arguments not listed here are passed through as in
        ``FluxInpaintPipeline.__call__``.

        Args:
            image: Concatenated panel image. After preprocessing to
                ``(height, width)``, its leftmost ``target_width`` pixels form the
                output panel.
            mask_image: Mask. Before every transformer call the latents are
                blended as ``(1 - mask) * image_latents + mask * latents``.
            target_width: Pixel width of the output panel. Defaults to ``width``
                (the whole image).
            tryon: Forwarded to ``_prepare_latent_image_ids``. It tags columns from
                ``2 * target_width`` on as a third panel.

        Returns:
            ``FluxPipelineOutput(images=...)``, or ``(image,)`` if ``return_dict``
            is False. Unless ``output_type == "latent"`` (full packed latents),
            only the output panel is decoded.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        if target_width is None:
            # No split requested: the whole image is the output panel.
            target_width = width

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            image,
            mask_image,
            strength,
            height,
            width,
            output_type=output_type,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            padding_mask_crop=padding_mask_crop,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Preprocess mask and image
        if padding_mask_crop is not None:
            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        init_image = self.image_processor.preprocess(
            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
        )
        init_image = init_image.to(dtype=torch.float32)

        # 3. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4

        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            target_width,
            tryon,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
        )
        if masked_image_latents is None:
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size,
            num_channels_latents,
            num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Re-impose the clean (un-noised) image latents outside the mask before
                # each transformer call; inside the mask the current latents are kept.
                latents = (1 - mask) * image_latents + mask * latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # NOTE: no final blend with image_latents is applied after the last scheduler step.
        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            # Keep only the output panel (leftmost target_width pixels) before decoding.
            latents = latents[:, :, :, : target_width // self.vae_scale_factor]
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(device=self.vae.device, dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def flux_pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Pack ``(B, C, H, W)`` latents into ``(B, (H//2)*(W//2), C*4)`` 2x2-patch tokens.

    Each token is one 2x2 spatial patch. Its features are ordered
    channel-major, then patch row, then patch column.
    """
    half_h, half_w = height // 2, width // 2
    patches = latents.view(batch_size, num_channels_latents, half_h, 2, half_w, 2)
    # Bring the patch-grid axes (rows, cols) in front of channel and in-patch axes.
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, half_h * half_w, num_channels_latents * 4)
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def flux_unpack_latents(latents, height, width, vae_scale_factor):
    """Turn ``(B, N, C*4)`` 2x2-patch tokens back into ``(B, C, H, W)`` latents.

    ``height`` and ``width`` are pixel sizes. They are divided by the VAE scale
    factor and rounded down to an even number, because the 2x2 packing needs
    even latent sides.
    """
    batch_size, _, packed_channels = latents.shape
    patch_px = vae_scale_factor * 2
    latent_h = 2 * (int(height) // patch_px)
    latent_w = 2 * (int(width) // patch_px)

    grid = latents.view(batch_size, latent_h // 2, latent_w // 2, packed_channels // 4, 2, 2)
    # Reorder to (B, C, rows, in-patch row, cols, in-patch col) before merging.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, packed_channels // (2 * 2), latent_h, latent_w)
def _grid_position_ids(rows, cols, panel, x_offset, device, dtype):
    """Return ``(rows * cols, 3)`` ids ``(panel, y, x + x_offset)`` in row-major order."""
    ids = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
    ids[..., 0] = panel
    ids[..., 1] = ids[..., 1] + torch.arange(rows, device=device)[:, None]  # y coordinate
    ids[..., 2] = ids[..., 2] + torch.arange(cols, device=device)[None, :] + x_offset  # x coordinate
    return ids.reshape(-1, ids.shape[-1])


# TODO: it is more reasonable to have the target positional ids starting at 0
def prepare_latent_image_ids(height, width_tgt, height_spa, width_spa, height_sub, width_sub, device, dtype):
    """Build concatenated positional ids for the target and optional condition grids.

    Each grid contributes ``rows * cols`` rows of ``(panel, y, x)``. The grids are
    concatenated in this order:

    1. the target grid (panel 0);
    2. the spatial-condition grid, if ``width_spa > 0``;
    3. the subject-condition grid, if ``width_sub > 0``.

    Enabled condition grids are numbered 1, 2 in that order.

    Args:
        height: Target grid height.
        width_tgt: Target grid width.
        height_spa: Spatial-condition grid height.
        width_spa: Spatial-condition grid width. ``0`` disables it; otherwise it
            must equal ``width_tgt``.
        height_sub: Subject-condition grid height.
        width_sub: Subject-condition grid width. ``0`` disables it.
        device: Device of the returned ids.
        dtype: Dtype of the returned ids.

    Returns:
        Tensor of shape ``(N, 3)``. Spatial-condition ids share the target's x
        positions; subject-condition x positions are shifted right by
        ``width_tgt``.

    Raises:
        ValueError: if ``width_spa`` is non-zero and differs from ``width_tgt``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if width_spa != 0 and width_tgt != width_spa:
        raise ValueError(
            f"width_spa must be 0 or equal to width_tgt, got width_spa={width_spa}, width_tgt={width_tgt}"
        )

    parts = [_grid_position_ids(height, width_tgt, 0, 0, device, dtype)]
    cond_mark = 0
    if width_spa > 0:
        cond_mark += 1
        parts.append(_grid_position_ids(height_spa, width_spa, cond_mark, 0, device, dtype))
    if width_sub > 0:
        cond_mark += 1
        parts.append(_grid_position_ids(height_sub, width_sub, cond_mark, width_tgt, device, dtype))
    return torch.cat(parts, dim=-2)
def crop_to_multiple_of_16(img, multiple=16):
    """Center-crop *img* so both sides are multiples of *multiple* (16 by default).

    Args:
        img: Image-like object exposing ``size`` as ``(width, height)`` and
            ``crop(box)``, e.g. a PIL image.
        multiple: Positive integer that both output sides must be divisible by.

    Returns:
        The result of ``img.crop((left, top, right, bottom))``. A side shorter than
        ``multiple`` is cropped to 0.

    Raises:
        ValueError: if ``multiple`` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")
    width, height = img.size
    # Largest multiples of `multiple` that fit in the current size.
    new_width = width - (width % multiple)
    new_height = height - (height % multiple)
    # Split the removed pixels evenly; any odd remainder is taken from the right/bottom.
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    return img.crop((left, top, left + new_width, top + new_height))
def resize_and_pad_to_size(image, target_width, target_height):
    """Letterbox *image* onto a white RGB canvas of ``target_width x target_height``.

    The image is scaled to fit inside the canvas while keeping its aspect ratio:
    the limiting side matches the canvas exactly, and the other side is
    truncated to an int (at least 1 px). The result is pasted centred.

    Args:
        image: PIL image or numpy array. Arrays are converted with
            ``Image.fromarray``.
        target_width: Canvas width in pixels.
        target_height: Canvas height in pixels.

    Returns:
        ``(padded_image, left, top, right, bottom)``, where the last four values
        are the white padding, in pixels, on each side.
    """
    # Convert numpy array to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    orig_width, orig_height = image.size
    target_ratio = target_width / target_height
    orig_ratio = orig_width / orig_height

    if orig_ratio > target_ratio:
        # Wider than the canvas: match the width, derive the height.
        new_width = target_width
        # Clamp to 1 px: extreme aspect ratios would otherwise truncate to 0,
        # and PIL cannot resize to a zero-sized image.
        new_height = max(1, int(new_width / orig_ratio))
    else:
        # Taller than (or same ratio as) the canvas: match the height, derive the width.
        new_height = target_height
        new_width = max(1, int(new_height * orig_ratio))

    resized_image = image.resize((new_width, new_height))

    # White canvas of the target size with the resized image centred on it.
    padded_image = Image.new('RGB', (target_width, target_height), 'white')
    left_padding = (target_width - new_width) // 2
    top_padding = (target_height - new_height) // 2
    padded_image.paste(resized_image, (left_padding, top_padding))

    right_padding = target_width - new_width - left_padding
    bottom_padding = target_height - new_height - top_padding
    return padded_image, left_padding, top_padding, right_padding, bottom_padding
def resize_by_height(image, height):
    """Scale *image* to *height* pixels tall, then center-crop both sides to multiples of 16.

    The new width is ``int(width * height / original_height)``, so the aspect
    ratio is kept up to truncation. Numpy arrays are first converted with
    ``Image.fromarray``.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    scaled_width = int(pil_image.width * height / pil_image.height)
    resized = pil_image.resize((scaled_width, height))
    return crop_to_multiple_of_16(resized)