| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import gc |
| | import os |
| | from typing import Any, Callable, Dict, List, Literal, Optional, Union |
| |
|
| | import numpy as np |
| | import PIL |
| | import torch |
| | from diffusers import ( |
| | FlowMatchEulerDiscreteScheduler, |
| | StableDiffusion3Pipeline, |
| | ) |
| | from diffusers.image_processor import PipelineImageInput |
| | from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import ( |
| | XLA_AVAILABLE, |
| | StableDiffusion3PipelineOutput, |
| | calculate_shift, |
| | retrieve_timesteps, |
| | ) |
| |
|
| | from src.loss.vector_swd import VectorSWDLoss |
| | from src.utils.color_space import rgb_to_lab |
| | from src.utils.image import from_torch, write_img |
| |
|
| | if XLA_AVAILABLE: |
| | from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import xm |
| |
|
| |
|
| | def _no_grad_noise(model, *args, **kw): |
| | """Forward pass with grad disabled; result is returned detached.""" |
| | with torch.no_grad(): |
| | return model(*args, **kw, return_dict=False)[0].detach() |
| |
|
| |
|
| | |
class SWStableDiffusion3Pipeline(StableDiffusion3Pipeline):
    """Stable Diffusion 3 pipeline with sliced-Wasserstein (SW) color guidance.

    During sampling, the latents are periodically shifted by an optimized
    offset so that the color distribution (in CIELAB space) of the predicted
    clean image matches a reference image, as measured by a ``VectorSWDLoss``.

    :meth:`setup_swd` must be called once before :meth:`__call__`.
    """

    # Sliced-Wasserstein loss module; populated by setup_swd() and required
    # by __call__ (asserted there).
    swd: VectorSWDLoss = None

    def setup_swd(
        self,
        num_projections: int = 64,
        use_ucv: bool = False,
        use_lcv: bool = False,
        distance: Literal["l1", "l2"] = "l1",
        num_new_candidates: int = 32,
        subsampling_factor: int = 1,
        sampling_mode: Literal["gaussian", "qmc"] = "qmc",
    ):
        """Create the ``VectorSWDLoss`` used for guidance on the pipeline device.

        Args:
            num_projections: Number of projections for the SW distance.
            use_ucv: Enable the UCV variant of the loss.
            use_lcv: Enable the LCV variant of the loss.
            distance: Per-projection ground distance, ``"l1"`` or ``"l2"``.
            num_new_candidates: Number of new candidates for the loss reservoir.
            subsampling_factor: Keep every n-th pixel of the flattened point
                clouds fed to the loss (memory/quality trade-off).
            sampling_mode: Projection sampling scheme (``"gaussian"``/``"qmc"``).
        """
        self.swd = VectorSWDLoss(
            num_proj=num_projections,
            distance=distance,
            use_ucv=use_ucv,
            use_lcv=use_lcv,
            num_new_candidates=num_new_candidates,
            missing_value_method="interpolate",
            ess_alpha=-1,
            sampling_mode=sampling_mode,
        ).to(self.device)
        self.subsampling_factor = subsampling_factor

    def do_sw_guidance(
        self,
        sw_steps,
        sw_u_lr,
        latents,
        t,
        prompt_embeds,
        pooled_prompt_embeds,
        pixels_ref,
        cur_iter_step,
        write_video_animation_path: Optional[str] = None,
    ):
        """Optimize an additive latent offset ``u`` with Adam for ``sw_steps``
        steps so the decoded x0-prediction matches ``pixels_ref`` in LAB color
        distribution under the SW loss; return ``latents + u`` (detached).

        Args:
            sw_steps: Number of Adam steps (0 disables guidance).
            sw_u_lr: Learning rate for the offset ``u``.
            latents: Current noisy latents.
            t: Current scheduler timestep (0-dim tensor).
            prompt_embeds / pooled_prompt_embeds: Text conditioning; when these
                are CFG-batched ([negative, positive]) only the positive half
                is used.
            pixels_ref: Reference image pixels in [0, 1] (RGB, bfloat16).
            cur_iter_step: Index of the outer denoising step (frame numbering).
            write_video_animation_path: Directory to dump per-step JPEG frames,
                or None to disable.
        """
        if sw_steps == 0:
            return latents

        # Under CFG the embeds stack [negative, positive]; guidance only
        # needs the positive (conditional) row.
        if latents.shape[0] != prompt_embeds.shape[0]:
            prompt_embeds = prompt_embeds[1].unsqueeze(0)
            pooled_prompt_embeds = pooled_prompt_embeds[1].unsqueeze(0)

        timestep = t.expand(latents.shape[0])

        # Convert the reference to LAB; same permute dance is applied to the
        # decoded prediction below so both point clouds share a layout.
        pixels_ref = (
            rgb_to_lab(pixels_ref.unsqueeze(0).clamp(0, 1).permute(0, 3, 1, 2))
            .permute(0, 2, 3, 1)
            .contiguous()
        )

        # Normalize LAB channels (L in [0, 100], a/b roughly [-128, 128])
        # into a comparable ~[0, 1] range per channel.
        csc_scaler = torch.tensor(
            [100, 2 * 128, 2 * 128], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)
        csc_bias = torch.tensor(
            [0, 0.5, 0.5], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)

        # Only the offset u is trained; the transformer forward is grad-free.
        u = torch.nn.Parameter(
            torch.zeros_like(latents, dtype=torch.bfloat16, device=latents.device)
        )
        optimizer = torch.optim.Adam([u], lr=sw_u_lr)

        for tt in range(sw_steps):
            optimizer.zero_grad()

            x_hat_t = latents.detach() + u
            noise_pred = _no_grad_noise(
                self.transformer,
                hidden_states=x_hat_t,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
            )

            # Flow-matching x0 estimate: x_0 = x_t - sigma_t * v_pred.
            sigma_t = self.scheduler.sigmas[
                self.scheduler.index_for_timestep(t)
            ]
            while sigma_t.ndim < x_hat_t.ndim:
                sigma_t = sigma_t.unsqueeze(-1)
            sigma_t = sigma_t.to(x_hat_t.dtype).to(latents.device)

            x_0 = x_hat_t - sigma_t * noise_pred

            # Decode to pixels; gradients flow to u through x_0 only
            # (noise_pred is detached above).
            img_unscaled = self.vae.decode(
                (x_0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
                return_dict=False,
            )[0]
            image = (img_unscaled * 0.5 + 0.5).clamp(0, 1)
            image_matched = (
                rgb_to_lab(image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
            )
            # Flatten to (1, 3, N) normalized-LAB point clouds.
            pred_seq = image_matched.view(1, 3, -1) / csc_scaler + csc_bias
            ref_seq = pixels_ref.view(1, 3, -1) / csc_scaler + csc_bias

            # Optional stride subsampling to bound SW memory usage.
            if self.subsampling_factor > 1:
                pred_seq = pred_seq[..., :: self.subsampling_factor]
                ref_seq = ref_seq[..., :: self.subsampling_factor]

            loss = self.swd(pred=pred_seq, gt=ref_seq, step=tt)

            loss.backward()
            optimizer.step()

            if write_video_animation_path is not None:
                # Global frame index across all guided denoising steps.
                frame_idx = cur_iter_step * sw_steps + tt
                write_img(
                    os.path.join(write_video_animation_path, f"{frame_idx:05d}.jpg"),
                    from_torch(img_unscaled.squeeze(0)),
                )

        latents = latents.detach() + u.detach()

        # Free the optimizer/graph memory before resuming normal sampling.
        gc.collect()
        torch.cuda.empty_cache()
        return latents

    def __call__(
        self,
        sw_reference: Optional[PIL.Image.Image] = None,
        sw_steps: int = 8,
        sw_u_lr: float = 0.05 * 10**3,
        num_guided_steps: Optional[int] = None,
        # --- arguments below mirror StableDiffusion3Pipeline.__call__ ---
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        cfg_rescale_phi: float = 0.7,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        skip_guidance_layers: List[int] = None,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_start: float = 0.01,
        mu: Optional[float] = None,
        write_video_animation_path: Optional[str] = None,
    ):
        """Sample an image, applying SW color guidance before each of the
        first ``num_guided_steps`` denoising steps (all steps when None).

        SW-specific args:
            sw_reference: Reference image for color guidance (None disables).
            sw_steps: Adam steps of SW guidance per denoising step.
            sw_u_lr: Learning rate for the SW latent offset.
            num_guided_steps: Number of denoising steps to guide; None means
                every step.
            cfg_rescale_phi: CFG rescale factor (0/None disables rescaling).
            write_video_animation_path: Directory for per-step JPEG frames.

        All remaining arguments follow ``StableDiffusion3Pipeline.__call__``.

        Raises:
            AssertionError: If :meth:`setup_swd` was not called first.
        """
        assert self.swd is not None, "SWD not initialized"

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Validate inputs (delegated to the parent pipeline).
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Determine batch size from prompt or precomputed embeds.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the prompts.
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None)
            if self.joint_attention_kwargs is not None
            else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            # Keep unbatched conditional embeds around for skip-layer guidance.
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )

        # 4. Prepare the initial latents.
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps (with dynamic shifting when configured).
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, height, width = latents.shape
            image_seq_len = (height // self.transformer.config.patch_size) * (
                width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare IP-adapter embeddings when requested.
        if (
            ip_adapter_image is not None and self.is_ip_adapter_active
        ) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {
                    "ip_adapter_image_embeds": ip_adapter_image_embeds
                }
            else:
                self._joint_attention_kwargs.update(
                    ip_adapter_image_embeds=ip_adapter_image_embeds
                )

        # 7. Prepare the SW reference image: resize so its longest side
        # matches the target's longest side, then convert to a bf16 tensor.
        if sw_reference is not None:
            target_max_size = max(height, width)
            reference_max_size = max(sw_reference.width, sw_reference.height)
            scale_factor = target_max_size / reference_max_size

            sw_reference = sw_reference.resize(
                (
                    int(sw_reference.width * scale_factor),
                    int(sw_reference.height * scale_factor),
                )
            )
            pixels_ref = (
                torch.Tensor(np.array(sw_reference).astype(np.float32) / 255)
                .permute(2, 0, 1)
                .to(device)
                .to(torch.bfloat16)
            )

        # 8. Denoising loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # SW color guidance: shift the latents before the regular
                # denoising step during the guided phase.
                if sw_reference is not None:
                    if num_guided_steps is None or i < num_guided_steps:
                        latents = self.do_sw_guidance(
                            sw_steps,
                            sw_u_lr,
                            latents,
                            t,
                            prompt_embeds,
                            pooled_prompt_embeds,
                            pixels_ref,
                            cur_iter_step=i,
                            write_video_animation_path=write_video_animation_path,
                        )
                    # Reset the SWD reservoir halfway through the guided phase.
                    # (Guard: num_guided_steps may be None, meaning "guide
                    # every step" — previously this raised a TypeError.)
                    if num_guided_steps is not None and i == num_guided_steps // 2:
                        self.swd.reset()

                # Duplicate latents for CFG (negative + positive branches).
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )

                with torch.no_grad():
                    timestep = t.expand(latent_model_input.shape[0])

                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        pooled_projections=pooled_prompt_embeds,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

                    # Classifier-free guidance combine.
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )

                        # Skip-layer guidance inside its configured step window.
                        should_skip_layers = (
                            True
                            if i > num_inference_steps * skip_layer_guidance_start
                            and i < num_inference_steps * skip_layer_guidance_stop
                            else False
                        )
                        if skip_guidance_layers is not None and should_skip_layers:
                            timestep = t.expand(latents.shape[0])
                            latent_model_input = latents
                            noise_pred_skip_layers = self.transformer(
                                hidden_states=latent_model_input,
                                timestep=timestep,
                                encoder_hidden_states=original_prompt_embeds,
                                pooled_projections=original_pooled_prompt_embeds,
                                joint_attention_kwargs=self.joint_attention_kwargs,
                                return_dict=False,
                                skip_layers=skip_guidance_layers,
                            )[0]
                            noise_pred = (
                                noise_pred
                                + (noise_pred_text - noise_pred_skip_layers)
                                * self._skip_layer_guidance_scale
                            )

                        # CFG rescale: blend the std of the guided prediction
                        # back toward the conditional prediction's std to
                        # counter over-saturation from high guidance scales.
                        if cfg_rescale_phi is not None and cfg_rescale_phi > 0:
                            sigma_pos = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
                            sigma_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)

                            factor = torch.lerp(
                                sigma_pos / (sigma_cfg + 1e-8),
                                torch.ones_like(sigma_cfg),
                                1.0 - cfg_rescale_phi,
                            )
                            noise_pred = noise_pred * factor

                # Scheduler update x_t -> x_{t-1}.
                latents_dtype = latents.dtype
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # MPS may silently upcast; restore the working dtype.
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(
                        self, i, t, callback_kwargs
                    )

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop(
                        "prompt_embeds", prompt_embeds
                    )
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds",
                        negative_pooled_prompt_embeds,
                    )

                # After the guided phase, keep dumping one frame per step so
                # the animation continues to the final image.
                # (Guard: num_guided_steps may be None — previously the
                # comparison raised a TypeError; when None every step is
                # guided and do_sw_guidance writes the frames instead.)
                if (
                    write_video_animation_path is not None
                    and num_guided_steps is not None
                    and i >= num_guided_steps
                ):
                    with torch.no_grad():
                        image = self.vae.decode(
                            (latents / self.vae.config.scaling_factor)
                            + self.vae.config.shift_factor,
                            return_dict=False,
                        )[0]
                    cur_frame_idx = i * sw_steps
                    write_img(
                        os.path.join(
                            write_video_animation_path,
                            f"{cur_frame_idx:05d}.jpg",
                        ),
                        from_torch(image.squeeze(0)),
                    )

                # Progress bar update honoring the scheduler order / warmup.
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps
                    and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 9. Decode and post-process the final latents.
        if output_type == "latent":
            image = latents
        else:
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(
                image.detach(), output_type=output_type
            )

        # Offload models (no-op unless CPU offload hooks are installed).
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)
| |
|
| |
|
def run(
    prompt: str,
    reference_image: PIL.Image.Image,
    model_path: str,
    num_inference_steps: int = 30,
    num_guided_steps: int = 28,
    guidance_scale: float = 5.0,
    cfg_rescale_phi: float = 0.7,
    sw_u_lr: float = 3e-3,
    sw_steps: int = 8,
    height: int = 768,
    width: int = 768,
    device: str = "cuda",
    seed: Optional[int] = None,
    # --- VectorSWDLoss configuration ---
    num_projections: int = 64,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    num_new_candidates: int = 32,
    subsampling_factor: int = 1,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    pipe: Optional[SWStableDiffusion3Pipeline] = None,
    compile: bool = False,
    video_animation_path: Optional[str] = None,
) -> PIL.Image.Image:
    """
    Generate an image using SW Guidance with a given prompt and reference image.

    Args:
        prompt (str): Text prompt to guide the generation
        reference_image (PIL.Image.Image): Reference image to guide the generation
        model_path (str): Path to the model weights
        num_inference_steps (int): Number of denoising steps
        num_guided_steps (int): Number of steps to apply SW guidance
        guidance_scale (float): Scale for classifier-free guidance
        cfg_rescale_phi (float): Rescale factor for classifier-free guidance
        sw_u_lr (float): Learning rate for SW guidance
        sw_steps (int): Number of SW optimization steps per guided denoising step
        height (int): Output image height
        width (int): Output image width
        device (str): Device to run the model on
        seed (Optional[int]): Seed for reproducible generation; None draws a
            fresh random state
        num_projections (int): Number of random projections for VectorSWDLoss
        use_ucv (bool): Use UCV variant of VectorSWDLoss
        use_lcv (bool): Use LCV variant of VectorSWDLoss
        distance (str): Distance metric for VectorSWDLoss ("l1" or "l2")
        num_new_candidates (int): Number of new candidates for the reservoir
        subsampling_factor (int): Factor to subsample points for SW computation.
            Higher values reduce memory usage but may affect quality.
        sampling_mode (str): Sampling mode for VectorSWDLoss.
        pipe (SWStableDiffusion3Pipeline): Pipeline to use for generation.
            If None, a new pipeline is created.
        compile (bool): Whether to compile the pipeline.
        video_animation_path (Optional[str]): Directory to write per-step
            animation frames; None disables frame dumping.

    Returns:
        PIL.Image.Image: Generated image
    """
    device = torch.device(device) if not isinstance(device, torch.device) else device
    if pipe is None:
        pipe = create_pipeline(model_path, device, compile=compile)

    # (Re)configure the SW loss on every run so repeated calls with a reused
    # pipeline pick up the requested settings.
    pipe.setup_swd(
        num_projections=num_projections,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        distance=distance,
        num_new_candidates=num_new_candidates,
        subsampling_factor=subsampling_factor,
        sampling_mode=sampling_mode,
    )

    if seed is not None:
        print(f"Using seed: {seed}")
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None

    image = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_guided_steps=num_guided_steps,
        guidance_scale=guidance_scale,
        cfg_rescale_phi=cfg_rescale_phi,
        sw_u_lr=sw_u_lr,
        sw_steps=sw_steps,
        height=height,
        width=width,
        sw_reference=reference_image,
        generator=generator,
        write_video_animation_path=video_animation_path,
    ).images[0]

    return image
| |
|
| |
|
def create_pipeline(model_path, device: str = "cuda", compile: bool = False):
    """Build an SWStableDiffusion3Pipeline from pretrained weights.

    Loads the model in bfloat16, swaps in a freshly-configured
    FlowMatchEulerDiscreteScheduler, moves everything to ``device`` and,
    when ``compile`` is set, torch.compile()s the transformer and the VAE
    decoder.
    """
    pipeline = SWStableDiffusion3Pipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    pipeline.to(device)
    if compile:
        pipeline.transformer = torch.compile(pipeline.transformer)
        pipeline.vae.decoder = torch.compile(pipeline.vae.decoder)
    return pipeline
| |
|