File size: 20,365 Bytes
3d1765e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 |
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
SiglipImageProcessor,
SiglipVisionModel,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
# Whether torch_xla is installed; the `xm` module alias is bound only in that case.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

logger = logging.get_logger(__name__)

# Usage example handed to `replace_example_docstring` on `__call__`.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from pipeline_stable_diffusion_3_S import StableDiffusion3SPipeline
>>> pipe = StableDiffusion3SPipeline.from_pretrained(
... "./sd3-finetuned-S", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A beautiful woman with a gun in her hand, wearing a bikini."
>>> image = pipe(prompt).images[0]
>>> image.save("sd3_S_version.png")
```
"""
def calculate_shift(
    image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15
):
    """Map an image token count linearly onto a scheduler shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Inputs outside that range are extrapolated,
    not clamped.

    Args:
        image_seq_len: Number of image tokens.
        base_seq_len, max_seq_len: Token counts anchoring the line.
        base_shift, max_shift: Shift values at those anchors.

    Returns:
        The interpolated shift ``mu``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def _set_timesteps_accepts(scheduler, name):
    """Return True if ``scheduler.set_timesteps`` can take the keyword argument ``name``."""
    params = inspect.signature(scheduler.set_timesteps).parameters
    # A `**kwargs` catch-all also accepts the keyword even if it is not named explicitly.
    return name in params or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Call ``scheduler.set_timesteps`` and return the resulting schedule.

    Exactly one mode is used: a custom ``timesteps`` list, a custom ``sigmas``
    list, or a plain ``num_inference_steps`` count. Extra ``kwargs`` (for
    example ``mu``) are forwarded to ``set_timesteps`` unchanged.

    Args:
        scheduler: Object exposing ``set_timesteps(...)`` and a ``timesteps`` attribute.
        num_inference_steps: Step count, used when neither custom list is given.
        device: Device forwarded to ``set_timesteps``.
        timesteps: Custom timestep schedule.
        sigmas: Custom sigma schedule.

    Returns:
        ``(scheduler.timesteps, num_inference_steps)``. In the custom modes the
        step count is the length of the scheduler's resulting timesteps.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are passed, or if the
            scheduler's ``set_timesteps`` cannot accept the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    if timesteps is not None:
        if not _set_timesteps_accepts(scheduler, "timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _set_timesteps_accepts(scheduler, "sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
class StableDiffusion3SPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
    """Text-to-image pipeline for the Stable Diffusion 3 "S" fine-tune.

    Prompts are encoded by two CLIP text encoders and an optional T5 encoder.
    An ``SD3Transformer2DModel`` denoises latents under a
    ``FlowMatchEulerDiscreteScheduler``, and the VAE decodes the final latents.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
    _optional_components = ["text_encoder_3", "tokenizer_3", "image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]

    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
        image_encoder: SiglipVisionModel = None,
        feature_extractor: SiglipImageProcessor = None,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # Fallback values keep construction working when a component is absent.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.tokenizer_3_max_length = 256
        has_transformer = hasattr(self, "transformer") and self.transformer is not None
        self.default_sample_size = self.transformer.config.sample_size if has_transformer else 128
        self.patch_size = self.transformer.config.patch_size if has_transformer else 2

    def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt=1, device=None, clip_skip=None, clip_model_index=0):
        """Encode ``prompt`` with one of the two CLIP text encoders.

        Args:
            prompt: A string or a list of strings.
            num_images_per_prompt: Number of copies of each embedding to produce.
            device: Target device; defaults to the pipeline's execution device.
            clip_skip: Extra final hidden layers to skip. ``None`` selects the
                penultimate hidden state.
            clip_model_index: 0 selects ``tokenizer``/``text_encoder``;
                1 selects ``tokenizer_2``/``text_encoder_2``.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)``, each with
            ``len(prompt) * num_images_per_prompt`` rows. The copies of each
            prompt are adjacent in both tensors.
        """
        device = device or self._execution_device
        tokenizer = (self.tokenizer, self.tokenizer_2)[clip_model_index]
        text_encoder = (self.text_encoder, self.text_encoder_2)[clip_model_index]
        prompt = [prompt] if isinstance(prompt, str) else prompt

        text_inputs = tokenizer(
            prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_tensors="pt"
        )
        encoder_output = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = encoder_output[0]
        hidden_index = -2 if clip_skip is None else -(clip_skip + 2)
        # Both CLIP outputs are cast to `text_encoder`'s dtype so they can be concatenated later.
        prompt_embeds = encoder_output.hidden_states[hidden_index].to(dtype=self.text_encoder.dtype, device=device)

        # repeat_interleave on dim 0 groups copies as [p0, p0, p1, p1, ...] in both tensors.
        # `t.repeat(1, n, 1).view(-1, d)` on the 2-D pooled tensor would instead tile the batch
        # ([p0, p1, p0, p1]) and pair prompts with the wrong pooled vectors when batch > 1.
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return prompt_embeds, pooled_prompt_embeds

    def _get_t5_prompt_embeds(self, prompt, num_images_per_prompt=1, device=None):
        """Encode ``prompt`` with the optional T5 encoder.

        ``text_encoder_3`` and ``tokenizer_3`` are optional components. When
        either is missing, an all-zeros tensor of shape
        ``(len(prompt) * num_images_per_prompt, tokenizer_max_length, joint_attention_dim)``
        is returned, so generation still runs on CLIP conditioning alone.
        """
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if getattr(self, "text_encoder_3", None) is None or getattr(self, "tokenizer_3", None) is None:
            return torch.zeros(
                (len(prompt) * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
                device=device,
                dtype=self.text_encoder.dtype,
            )
        text_inputs = self.tokenizer_3(
            prompt, padding="max_length", max_length=self.tokenizer_3_max_length, truncation=True, return_tensors="pt"
        )
        t5_output = self.text_encoder_3(text_inputs.input_ids.to(device))
        return t5_output.last_hidden_state.repeat_interleave(num_images_per_prompt, dim=0)

    @staticmethod
    def _pad_last_dim(tensor, size):
        """Zero-pad the last dimension of ``tensor`` on the right up to ``size``; no-op if it is already >= ``size``."""
        missing = size - tensor.shape[-1]
        return torch.nn.functional.pad(tensor, (0, missing)) if missing > 0 else tensor

    def _encode_joint_prompt(self, prompt, prompt_2, prompt_3, device, num_images_per_prompt, clip_skip):
        """Build the transformer's ``encoder_hidden_states`` and pooled projection for one set of prompts.

        The two CLIP sequence embeddings are concatenated on the feature axis.
        The CLIP and T5 embeddings are then zero-padded up to
        ``joint_attention_dim`` and concatenated along dim 1 (the token axis).
        The pooled output is the feature-axis concatenation of both CLIP pooled
        outputs.
        """
        target_dim = self.transformer.config.joint_attention_dim
        clip_embed_1, pooled_1 = self._get_clip_prompt_embeds(
            prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_skip=clip_skip, clip_model_index=0
        )
        clip_embed_2, pooled_2 = self._get_clip_prompt_embeds(
            prompt_2, num_images_per_prompt=num_images_per_prompt, device=device, clip_skip=clip_skip, clip_model_index=1
        )
        clip_embeds = torch.cat([clip_embed_1, clip_embed_2], dim=-1)
        pooled_prompt_embeds = torch.cat([pooled_1, pooled_2], dim=-1)
        t5_embeds = self._get_t5_prompt_embeds(prompt_3, num_images_per_prompt=num_images_per_prompt, device=device)
        clip_embeds = self._pad_last_dim(clip_embeds, target_dim)
        t5_embeds = self._pad_last_dim(t5_embeds, target_dim)
        return torch.cat([clip_embeds, t5_embeds], dim=1), pooled_prompt_embeds

    def encode_prompt(
        self,
        prompt,
        prompt_2=None,
        prompt_3=None,
        device=None,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None,
        negative_prompt_3=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        clip_skip=None,
        lora_scale=None,
    ):
        """Encode positive and (optionally) negative prompts for the transformer.

        Embeddings passed in explicitly are returned unchanged. ``prompt_2`` and
        ``prompt_3`` default to ``prompt``. Negative prompts default to ``""``,
        and ``negative_prompt_2``/``negative_prompt_3`` default to
        ``negative_prompt``. Negative CLIP embeddings always use
        ``clip_skip=None``.

        ``lora_scale`` is applied only to the two CLIP text encoders, and only
        for the duration of this call. The transformer is scaled per forward
        pass through ``joint_attention_kwargs["scale"]`` in ``__call__``. Scaling
        it here would be undone before denoising starts.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)``.
            The negative entries are ``None`` when guidance is off and none were passed.
        """
        device = device or self._execution_device
        text_encoders = (self.text_encoder, self.text_encoder_2)

        scaled = False
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale
            if USE_PEFT_BACKEND:
                for encoder in text_encoders:
                    if encoder is not None:
                        scale_lora_layers(encoder, lora_scale)
                scaled = True

        try:
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            if prompt_embeds is None:
                prompt_embeds, pooled_prompt_embeds = self._encode_joint_prompt(
                    prompt, prompt_2 or prompt, prompt_3 or prompt, device, num_images_per_prompt, clip_skip
                )

            if do_classifier_free_guidance and negative_prompt_embeds is None:
                negative_prompt = negative_prompt or ""
                negative_prompt_2 = negative_prompt_2 or negative_prompt
                negative_prompt_3 = negative_prompt_3 or negative_prompt
                # A single negative string is broadcast to every prompt in the batch.
                negative_prompt, negative_prompt_2, negative_prompt_3 = (
                    batch_size * [text] if isinstance(text, str) else text
                    for text in (negative_prompt, negative_prompt_2, negative_prompt_3)
                )
                negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_joint_prompt(
                    negative_prompt, negative_prompt_2, negative_prompt_3, device, num_images_per_prompt, None
                )
        finally:
            # Restore the text encoders even if encoding raised, so the scale does not leak into later calls.
            if scaled:
                for encoder in text_encoders:
                    if encoder is not None:
                        unscale_lora_layers(encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def check_inputs(
        self,
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=None,
        negative_prompt_2=None,
        negative_prompt_3=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate ``__call__`` arguments and raise ``ValueError`` on the first problem found."""
        multiple = self.vae_scale_factor * self.patch_size
        if height % multiple != 0 or width % multiple != 0:
            raise ValueError(f"`height` and `width` must be divisible by {multiple}")
        if callback_on_step_end_tensor_inputs is not None:
            unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
            if unknown:
                raise ValueError(
                    f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
                )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
        if prompt is None and prompt_embeds is None:
            raise ValueError("Provide either `prompt` or `prompt_embeds`.")
        for name, value in (("prompt", prompt), ("prompt_2", prompt_2), ("prompt_3", prompt_3)):
            if value is not None and not isinstance(value, (str, list)):
                raise ValueError(f"`{name}` has to be of type `str` or `list` but is {type(value)}")
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError("If `prompt_embeds` are provided, `pooled_prompt_embeds` must also be passed.")
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` must also be passed."
            )
        # Positive and negative embeddings are concatenated on dim 0 for guidance, so all other dims must agree.
        if (
            prompt_embeds is not None
            and negative_prompt_embeds is not None
            and prompt_embeds.shape != negative_prompt_embeds.shape
        ):
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape, but got "
                f"{tuple(prompt_embeds.shape)} and {tuple(negative_prompt_embeds.shape)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Return initial latents.

        Provided ``latents`` are only moved to ``device``/``dtype``. Otherwise
        noise of shape ``(batch_size, num_channels_latents, height // vae_scale_factor,
        width // vae_scale_factor)`` is sampled with ``randn_tensor``.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
        shape = (batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor)
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    @property
    def do_classifier_free_guidance(self):
        # Guidance is active only for scales strictly greater than 1.
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        mu: Optional[float] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt, prompt_2, prompt_3: Text for the first CLIP, second CLIP and T5 encoders.
                ``prompt_2``/``prompt_3`` default to ``prompt``.
            height, width: Output size in pixels. They default to
                ``default_sample_size * vae_scale_factor`` and must be divisible by
                ``vae_scale_factor * patch_size``.
            num_inference_steps, sigmas: Schedule controls passed to ``retrieve_timesteps``.
            guidance_scale: Classifier-free guidance weight; guidance is enabled when > 1.
            negative_prompt, negative_prompt_2, negative_prompt_3: Negative prompts for the same
                three encoders (default ``""``).
            num_images_per_prompt: Number of images generated per prompt.
            generator, latents: Control the initial noise; passing ``latents`` skips sampling.
            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds:
                Precomputed embeddings used instead of encoding text.
            output_type: ``"latent"`` returns raw latents; any other value is passed to the image processor.
            return_dict: Return a ``StableDiffusion3PipelineOutput`` instead of a tuple.
            joint_attention_kwargs: Forwarded to the transformer as ``joint_attention_kwargs``.
            lora_scale: LoRA scale for the CLIP text encoders. It is also forwarded to the transformer
                as ``joint_attention_kwargs["scale"]``. When ``None``, it is read from
                ``joint_attention_kwargs["scale"]`` if present.
            clip_skip: Extra final CLIP layers to skip for the positive prompt.
            callback_on_step_end: Called as ``callback(pipe, step_index, timestep, tensors)`` after every
                denoising step. Entries it returns under ``latents``, ``prompt_embeds`` or
                ``pooled_prompt_embeds`` replace the current tensors.
            callback_on_step_end_tensor_inputs: Names from ``_callback_tensor_inputs`` passed to the callback.
            mu: Scheduler shift. When ``None`` and the scheduler uses dynamic shifting, it is computed
                from the latent size and the scheduler config.

        Examples:
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        # One LoRA scale drives both the text encoders (inside encode_prompt) and the transformer.
        # A copy is made so the caller's dict is never mutated.
        # NOTE: relies on SD3Transformer2DModel consuming `joint_attention_kwargs["scale"]` as its
        # PEFT LoRA scale for each forward pass — confirm for the pinned diffusers version.
        if lora_scale is not None:
            joint_attention_kwargs = {**(joint_attention_kwargs or {}), "scale": lora_scale}
        elif joint_attention_kwargs is not None:
            lora_scale = joint_attention_kwargs.get("scale", None)

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        if isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device

        (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            lora_scale=lora_scale,
        )
        if self.do_classifier_free_guidance:
            # Negative half first; matches the `chunk(2)` order in the denoising loop.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, latent_h, latent_w = latents.shape
            patch = self.transformer.config.patch_size
            image_seq_len = (latent_h // patch) * (latent_w // patch)
            # Honour the scheduler's own shift settings; fallbacks equal calculate_shift's defaults.
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.15),
            )
        if mu is not None:
            scheduler_kwargs["mu"] = mu

        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                timestep = t.expand(latent_model_input.shape[0])
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    step_tensors = {
                        "latents": latents,
                        "prompt_embeds": prompt_embeds,
                        "pooled_prompt_embeds": pooled_prompt_embeds,
                    }
                    callback_kwargs = {k: step_tensors[k] for k in (callback_on_step_end_tensor_inputs or [])}
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    # Flush the lazily traced XLA graph once per step.
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return StableDiffusion3PipelineOutput(images=image)