# Implementation of the SW Guidance method with our enhanced SWD implementation
# See: https://github.com/alobashev/sw-guidance/ for the original implementation
#
# Alexander Lobashev, Maria Larchenko, Dmitry Guskov
# Color Conditional Generation with Sliced Wasserstein Guidance
# https://arxiv.org/abs/2503.19034
import gc
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
)
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
    XLA_AVAILABLE,
    StableDiffusion3PipelineOutput,
    calculate_shift,
    retrieve_timesteps,
)

from src.loss.vector_swd import VectorSWDLoss
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, write_img

if XLA_AVAILABLE:
    from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import xm

def _no_grad_noise(model, *args, **kw):
    """Forward pass with grad disabled; the result is returned detached."""
    with torch.no_grad():
        return model(*args, **kw, return_dict=False)[0].detach()

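
# For intuition only: a minimal, self-contained sketch of the sliced Wasserstein
# distance that VectorSWDLoss is built around. This is NOT the repo's
# VectorSWDLoss (which adds projection reservoirs, control variates, and QMC
# sampling); it just shows the core idea: project both point clouds onto random
# unit directions, sort each 1D projection, and average the pointwise distances.
def _swd_sketch(
    pred: torch.Tensor,  # (B, D, N) point cloud, e.g. Lab pixels
    gt: torch.Tensor,  # (B, D, N) reference point cloud, same N as pred
    num_proj: int = 64,
) -> torch.Tensor:
    proj = torch.randn(num_proj, pred.shape[1], device=pred.device, dtype=pred.dtype)
    proj = proj / proj.norm(dim=1, keepdim=True)  # unit directions on S^{D-1}
    # (B, P, N): every point projected onto every direction, then sorted;
    # sorted 1D samples give the optimal transport coupling in closed form
    pred_1d = torch.einsum("pd,bdn->bpn", proj, pred).sort(dim=-1).values
    gt_1d = torch.einsum("pd,bdn->bpn", proj, gt).sort(dim=-1).values
    return (pred_1d - gt_1d).abs().mean()
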
# ---------------- pipeline with an explicit SW-guidance forward pass
class SWStableDiffusion3Pipeline(StableDiffusion3Pipeline):
    swd: Optional[VectorSWDLoss] = None

    def setup_swd(
        self,
        num_projections: int = 64,
        use_ucv: bool = False,
        use_lcv: bool = False,
        distance: Literal["l1", "l2"] = "l1",
        num_new_candidates: int = 32,
        subsampling_factor: int = 1,
        sampling_mode: Literal["gaussian", "qmc"] = "qmc",
    ):
        self.swd = VectorSWDLoss(
            num_proj=num_projections,
            distance=distance,
            use_ucv=use_ucv,
            use_lcv=use_lcv,
            num_new_candidates=num_new_candidates,
            missing_value_method="interpolate",
            ess_alpha=-1,
            sampling_mode=sampling_mode,
        ).to(self.device)
        self.subsampling_factor = subsampling_factor

    def do_sw_guidance(
        self,
        sw_steps,
        sw_u_lr,
        latents,
        t,
        prompt_embeds,
        pooled_prompt_embeds,
        pixels_ref,
        cur_iter_step,
        write_video_animation_path: Optional[str] = None,
    ):
        if sw_steps == 0:
            return latents
        if latents.shape[0] != prompt_embeds.shape[0]:
            # keep only the conditional half of the CFG-stacked embeddings
            prompt_embeds = prompt_embeds[1].unsqueeze(0)
            pooled_prompt_embeds = pooled_prompt_embeds[1].unsqueeze(0)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0])
        # reference pixels arrive as (3, H, W) RGB in [0, 1]; convert to Lab
        pixels_ref = rgb_to_lab(pixels_ref.unsqueeze(0).clamp(0, 1))  # (1, 3, H, W)
        # map L in [0, 100] and a, b in [-128, 128] to roughly [0, 1]
        csc_scaler = torch.tensor(
            [100, 2 * 128, 2 * 128], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)
        csc_bias = torch.tensor(
            [0, 0.5, 0.5], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)
        u = torch.nn.Parameter(
            torch.zeros_like(latents, dtype=torch.bfloat16, device=latents.device)
        )
        optimizer = torch.optim.Adam([u], lr=sw_u_lr)
        for tt in range(sw_steps):
            optimizer.zero_grad()
            x_hat_t = latents.detach() + u
            noise_pred = _no_grad_noise(
                self.transformer,
                hidden_states=x_hat_t,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
            )
            # ------------ Compute x_0
            # flow-matching prediction: x_0 = x_t - sigma_t * v_theta(x_t, t)
            sigma_t = self.scheduler.sigmas[
                self.scheduler.index_for_timestep(t)
            ]  # scalar
            while sigma_t.ndim < x_hat_t.ndim:
                sigma_t = sigma_t.unsqueeze(-1)
            sigma_t = sigma_t.to(x_hat_t.dtype).to(latents.device)
            x_0 = x_hat_t - sigma_t * noise_pred
            # ------------ Compute loss
            img_unscaled = self.vae.decode(
                (x_0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
                return_dict=False,
            )[0]
            image = (img_unscaled * 0.5 + 0.5).clamp(0, 1)
            image_lab = rgb_to_lab(image)  # (1, 3, H, W)
            # reshape to (B, D, N) where D=3, N = H*W
            pred_seq = image_lab.view(1, 3, -1) / csc_scaler + csc_bias
            ref_seq = pixels_ref.view(1, 3, -1) / csc_scaler + csc_bias
            # Apply subsampling if enabled
            if self.subsampling_factor > 1:
                pred_seq = pred_seq[..., :: self.subsampling_factor]
                ref_seq = ref_seq[..., :: self.subsampling_factor]
            loss = self.swd(pred=pred_seq, gt=ref_seq, step=tt)
            loss.backward()
            optimizer.step()
            if write_video_animation_path is not None:
                frame_idx = cur_iter_step * sw_steps + tt
                write_img(
                    os.path.join(write_video_animation_path, f"{frame_idx:05d}.jpg"),
                    from_torch(img_unscaled.squeeze(0)),
                )
        latents = latents.detach() + u.detach()
        gc.collect()
        torch.cuda.empty_cache()
        return latents

    def __call__(
        self,
        sw_reference: Optional[PIL.Image.Image] = None,
        sw_steps: int = 8,
        sw_u_lr: float = 0.05 * 10**3,
        num_guided_steps: Optional[int] = None,
        # -----------------------------------
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        cfg_rescale_phi: float = 0.7,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        skip_guidance_layers: Optional[List[int]] = None,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_start: float = 0.01,
        mu: Optional[float] = None,
        write_video_animation_path: Optional[str] = None,
    ):
        assert self.swd is not None, "SWD not initialized; call setup_swd() first"
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None)
            if self.joint_attention_kwargs is not None
            else None
        )
        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )
        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # 5. Prepare timesteps
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            # use separate names so the pixel-space height/width used later for
            # the reference resize are not shadowed by latent dimensions
            _, _, latent_height, latent_width = latents.shape
            image_seq_len = (latent_height // self.transformer.config.patch_size) * (
                latent_width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)
        # 6. Prepare image embeddings
        if (
            ip_adapter_image is not None and self.is_ip_adapter_active
        ) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )
            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {
                    "ip_adapter_image_embeds": ip_adapter_image_embeds
                }
            else:
                self._joint_attention_kwargs.update(
                    ip_adapter_image_embeds=ip_adapter_image_embeds
                )
        if sw_reference is not None:
            # Resize the reference so its longer side matches the longer side
            # of the output image
            target_max_size = max(height, width)
            reference_max_size = max(sw_reference.width, sw_reference.height)
            scale_factor = target_max_size / reference_max_size
            sw_reference = sw_reference.resize(
                (
                    int(sw_reference.width * scale_factor),
                    int(sw_reference.height * scale_factor),
                )
            )
            # (3, H, W) RGB in [0, 1]
            pixels_ref = (
                torch.Tensor(np.array(sw_reference).astype(np.float32) / 255)
                .permute(2, 0, 1)
                .to(device)
                .to(torch.bfloat16)
            )
        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                # SW Guidance
                if sw_reference is not None:
                    if num_guided_steps is None or i < num_guided_steps:
                        latents = self.do_sw_guidance(
                            sw_steps,
                            sw_u_lr,
                            latents,
                            t,
                            prompt_embeds,
                            pooled_prompt_embeds,
                            pixels_ref,
                            cur_iter_step=i,
                            write_video_animation_path=write_video_animation_path,
                        )
                    if num_guided_steps is not None and i == num_guided_steps // 2:
                        self.swd.reset()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                with torch.no_grad():
                    # broadcast to batch dimension in a way that's compatible
                    # with ONNX/Core ML
                    timestep = t.expand(latent_model_input.shape[0])
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        pooled_projections=pooled_prompt_embeds,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )
                        should_skip_layers = (
                            i > num_inference_steps * skip_layer_guidance_start
                            and i < num_inference_steps * skip_layer_guidance_stop
                        )
                        if skip_guidance_layers is not None and should_skip_layers:
                            timestep = t.expand(latents.shape[0])
                            latent_model_input = latents
                            noise_pred_skip_layers = self.transformer(
                                hidden_states=latent_model_input,
                                timestep=timestep,
                                encoder_hidden_states=original_prompt_embeds,
                                pooled_projections=original_pooled_prompt_embeds,
                                joint_attention_kwargs=self.joint_attention_kwargs,
                                return_dict=False,
                                skip_layers=skip_guidance_layers,
                            )[0]
                            noise_pred = (
                                noise_pred
                                + (noise_pred_text - noise_pred_skip_layers)
                                * self._skip_layer_guidance_scale
                            )
                        # Based on Sec. 3.4 of Lin, Liu, Li, Yang -
                        # Common Diffusion Noise Schedules and Sample Steps are Flawed
                        # https://arxiv.org/abs/2305.08891
                        # While flow matching is free of most of those issues, a high
                        # CFG scale can still cause over-exposure, as discussed there.
                        if cfg_rescale_phi is not None and cfg_rescale_phi > 0:
                            # sigma_pos and sigma_cfg are per-sample (B, 1, 1, 1) stdevs
                            sigma_pos = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
                            sigma_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)
                            # Linear blend between the raw ratio and 1,
                            # cf. Eq. (15)-(16) in the paper
                            factor = torch.lerp(
                                sigma_pos / (sigma_cfg + 1e-8),  # avoid div-by-zero
                                torch.ones_like(sigma_cfg),
                                1.0 - cfg_rescale_phi,
                            )
                            noise_pred = noise_pred * factor
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a
                        # PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(
                        self, i, t, callback_kwargs
                    )
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop(
                        "prompt_embeds", prompt_embeds
                    )
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds",
                        negative_pooled_prompt_embeds,
                    )
                if (
                    write_video_animation_path is not None
                    and num_guided_steps is not None
                    and i >= num_guided_steps
                ):
                    with torch.no_grad():
                        image = self.vae.decode(
                            (latents / self.vae.config.scaling_factor)
                            + self.vae.config.shift_factor,
                            return_dict=False,
                        )[0]
                    cur_frame_idx = i * sw_steps
                    write_img(
                        os.path.join(
                            write_video_animation_path,
                            f"{cur_frame_idx:05d}.jpg",
                        ),
                        from_torch(image.squeeze(0)),
                    )
                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps
                    and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                if XLA_AVAILABLE:
                    xm.mark_step()
| if output_type == "latent": | |
| image = latents | |
| else: | |
| latents = ( | |
| latents / self.vae.config.scaling_factor | |
| ) + self.vae.config.shift_factor | |
| image = self.vae.decode(latents, return_dict=False)[0] | |
| image = self.image_processor.postprocess( | |
| image.detach(), output_type=output_type | |
| ) | |
| # Offload all models | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (image,) | |
| return StableDiffusion3PipelineOutput(images=image) | |
def run(
    prompt: str,
    reference_image: PIL.Image.Image,
    model_path: str,
    num_inference_steps: int = 30,
    num_guided_steps: int = 28,
    guidance_scale: float = 5.0,
    cfg_rescale_phi: float = 0.7,
    sw_u_lr: float = 3e-3,
    sw_steps: int = 8,
    height: int = 768,
    width: int = 768,
    device: str = "cuda",
    seed: Optional[int] = None,
    # SW-specific parameters
    num_projections: int = 64,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    num_new_candidates: int = 32,
    subsampling_factor: int = 1,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    pipe: Optional[SWStableDiffusion3Pipeline] = None,
    compile: bool = False,
    video_animation_path: Optional[str] = None,
) -> PIL.Image.Image:
| """ | |
| Generate an image using SW Guidance with a given prompt and reference image. | |
| Args: | |
| prompt (str): Text prompt to guide the generation | |
| reference_image (PIL.Image.Image): Reference image to guide the generation | |
| model_path (str): Path to the model weights | |
| num_inference_steps (int): Number of denoising steps | |
| num_guided_steps (int): Number of steps to apply SW guidance | |
| guidance_scale (float): Scale for classifier-free guidance | |
| cfg_rescale_phi (float): Rescale factor for classifier-free guidance | |
| sw_u_lr (float): Learning rate for SW guidance | |
| sw_steps (int): Number of steps to apply SW guidance | |
| height (int): Output image height | |
| width (int): Output image width | |
| device (str): Device to run the model on | |
| num_projections (int): Number of random projections for VectorSWDLoss | |
| use_ucv (bool): Use UCV variant of VectorSWDLoss | |
| use_lcv (bool): Use LCV variant of VectorSWDLoss | |
| distance (str): Distance metric for VectorSWDLoss ("l1" or "l2") | |
| refresh_projections_every_n_steps (int): How often to refresh projections | |
| num_new_candidates (int): Number of new candidates for the reservoir | |
| subsampling_factor (int): Factor to subsample points for SW computation. | |
| Higher values reduce memory usage but may affect quality. | |
| sampling_mode (str): Sampling mode for VectorSWDLoss. | |
| pipe (SWStableDiffusion3Pipeline): Pipeline to use for generation. | |
| If None, a new pipeline is created. | |
| compile (bool): Whether to compile the pipeline. | |
| Returns: | |
| PIL.Image.Image: Generated image | |
| """ | |
    # Normalize device to torch.device for robustness
    device = torch.device(device) if not isinstance(device, torch.device) else device
    if pipe is None:
        pipe = create_pipeline(model_path, device, compile=compile)
    pipe.setup_swd(
        num_projections=num_projections,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        distance=distance,
        num_new_candidates=num_new_candidates,
        subsampling_factor=subsampling_factor,
        sampling_mode=sampling_mode,
    )
    if seed is not None:
        print(f"Using seed: {seed}")
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None
    image = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_guided_steps=num_guided_steps,
        guidance_scale=guidance_scale,
        cfg_rescale_phi=cfg_rescale_phi,
        sw_u_lr=sw_u_lr,
        sw_steps=sw_steps,
        height=height,
        width=width,
        sw_reference=reference_image,
        generator=generator,
        write_video_animation_path=video_animation_path,
    ).images[0]
    return image

| def create_pipeline(model_path, device: str = "cuda", compile: bool = False): | |
| pipe = SWStableDiffusion3Pipeline.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.bfloat16, | |
| ) | |
| pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config) | |
| pipe.to(device) | |
| if compile: | |
| pipe.transformer = torch.compile(pipe.transformer) | |
| pipe.vae.decoder = torch.compile(pipe.vae.decoder) | |
| return pipe | |
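
# Minimal usage sketch. The model id and file names below are placeholder
# assumptions for illustration, not values shipped with this repo.
if __name__ == "__main__":
    ref = PIL.Image.open("reference.jpg").convert("RGB")
    out = run(
        prompt="a cozy cabin in a snowy forest",
        reference_image=ref,
        model_path="stabilityai/stable-diffusion-3-medium-diffusers",
        num_inference_steps=30,
        num_guided_steps=28,
        seed=0,
    )
    out.save("sw_guided.png")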