manbeast3b commited on
Commit
36a77a0
·
verified ·
1 Parent(s): 18d4e6f

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +815 -17
src/pipeline.py CHANGED
@@ -27,6 +27,36 @@ import torch.nn as nn
27
  import torch.nn.functional as F
28
  from torchao.quantization import quantize_, float8_weight_only, int8_dynamic_activation_int4_weight
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # preconfigs
31
  import os
32
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
@@ -41,24 +71,783 @@ Pipeline = None
41
  ckpt_id = "manbeast3b/flux.1-schnell-full1"
42
  ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def load_pipeline() -> Pipeline:
45
  model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
46
  revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
47
-
48
- # text_encoder_2 = T5EncoderModel.from_pretrained(
49
- # model_name,
50
- # revision=text_enc_revision,
51
- # subfolder="text_encoder_2",
52
- # torch_dtype=torch.bfloat16
53
- # ).to(memory_format=torch.channels_last)
54
-
55
- # vae = AutoencoderKL.from_pretrained(
56
- # ckpt_id,
57
- # revision=ckpt_revision,
58
- # subfolder="vae",
59
- # local_files_only=True,
60
- # torch_dtype=torch.bfloat16
61
- # ).to(memory_format=torch.channels_last)
62
 
63
  hub_model_dir = os.path.join(
64
  HF_HUB_CACHE,
@@ -83,8 +872,17 @@ def load_pipeline() -> Pipeline:
83
  )
84
  # pipeline.vae = torch.compile(vae)
85
  pipeline.to("cuda")
86
- pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
87
- quantize_(pipeline.vae, int8_dynamic_activation_int4_weight())
 
 
 
 
 
 
 
 
 
88
 
89
  warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
90
  for _ in range(1):
 
27
  import torch.nn.functional as F
28
  from torchao.quantization import quantize_, float8_weight_only, int8_dynamic_activation_int4_weight
29
 
30
+
31
+ import inspect
32
+ from typing import Any, Callable, Dict, List, Optional, Union
33
+ import numpy as np
34
+ import torch
35
+ from transformers import (
36
+ CLIPImageProcessor,
37
+ CLIPTextModel,
38
+ CLIPTokenizer,
39
+ CLIPVisionModelWithProjection,
40
+ T5EncoderModel,
41
+ T5TokenizerFast,
42
+ )
43
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
44
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
45
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
46
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
47
+ from diffusers.utils import (
48
+ USE_PEFT_BACKEND,
49
+ is_torch_xla_available,
50
+ logging,
51
+ replace_example_docstring,
52
+ scale_lora_layers,
53
+ unscale_lora_layers,
54
+ )
55
+ from diffusers.utils.torch_utils import randn_tensor
56
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
57
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
58
+
59
+
60
  # preconfigs
61
  import os
62
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
 
71
  ckpt_id = "manbeast3b/flux.1-schnell-full1"
72
  ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
73
 
74
+
75
+
76
+
77
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
78
+
79
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the flow-matching timestep shift ``mu``.

    The shift grows from ``base_shift`` at ``base_seq_len`` image tokens up to
    ``max_shift`` at ``max_seq_len`` tokens; values outside that range are
    extrapolated on the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
90
+
91
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.

    At most one of ``timesteps``/``sigmas`` may be supplied; when neither is,
    the scheduler's default spacing for ``num_inference_steps`` is used.
    Extra ``kwargs`` are forwarded to ``scheduler.set_timesteps``.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(name):
        # Not every scheduler supports custom schedules; inspect the signature.
        return name in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule path: the effective step count is whatever the scheduler
    # actually produced.
    final_timesteps = scheduler.timesteps
    return final_timesteps, len(final_timesteps)
125
+
126
+
127
+ class FluxPipeline(
128
+ DiffusionPipeline,
129
+ FluxLoraLoaderMixin,
130
+ FromSingleFileMixin,
131
+ TextualInversionLoaderMixin,
132
+ FluxIPAdapterMixin,
133
+ ):
134
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
135
+ _optional_components = ["image_encoder", "feature_extractor"]
136
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
137
+
138
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        """Register all Flux sub-models and derive latent-geometry constants.

        The IP-adapter components (``image_encoder``/``feature_extractor``)
        are optional and default to ``None``.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # 2**(num_blocks - 1); falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        # CLIP tokenizer window (normally 77); used for truncation warnings.
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        # Default latent sample size; 128 * vae_scale_factor pixels.
        self.default_sample_size = 128
171
+
172
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize *prompt* with the T5 tokenizer and embed it with ``text_encoder_2``.

        Returns per-token embeddings of shape
        ``(batch * num_images_per_prompt, seq_len, dim)``.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Expand any textual-inversion placeholder tokens before tokenizing.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        # Truncation only warns — it never fails the call.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        # Final embeddings adopt the T5 encoder's dtype (not the `dtype` arg).
        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds
220
+
221
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Return the pooled CLIP text embedding for *prompt*.

        Shape: ``(batch * num_images_per_prompt, dim)`` — Flux only consumes
        the pooled output, not per-token hidden states.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Expand any textual-inversion placeholder tokens before tokenizing.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        # Warn (don't fail) when the prompt exceeds CLIP's token window.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds
264
+
265
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        """Encode prompts into ``(T5 sequence embeds, pooled CLIP embeds, text ids)``.

        When ``prompt_embeds`` is supplied, both encoders are skipped.
        ``prompt_2`` (the T5 prompt) defaults to ``prompt`` when omitted.
        ``text_ids`` is an all-zero ``(seq_len, 3)`` position tensor.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        # Text token position ids are all zeros for Flux.
        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids
322
+
323
+ def encode_image(self, image, device, num_images_per_prompt):
324
+ dtype = next(self.image_encoder.parameters()).dtype
325
+
326
+ if not isinstance(image, torch.Tensor):
327
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
328
+
329
+ image = image.to(device=device, dtype=dtype)
330
+ image_embeds = self.image_encoder(image).image_embeds
331
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
332
+ return image_embeds
333
+
334
+ def prepare_ip_adapter_image_embeds(
335
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
336
+ ):
337
+ image_embeds = []
338
+ if ip_adapter_image_embeds is None:
339
+ if not isinstance(ip_adapter_image, list):
340
+ ip_adapter_image = [ip_adapter_image]
341
+
342
+ if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
343
+ raise ValueError(
344
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
345
+ )
346
+
347
+ for single_ip_adapter_image, image_proj_layer in zip(
348
+ ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
349
+ ):
350
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
351
+
352
+ image_embeds.append(single_image_embeds[None, :])
353
+ else:
354
+ for single_image_embeds in ip_adapter_image_embeds:
355
+ image_embeds.append(single_image_embeds)
356
+
357
+ ip_adapter_image_embeds = []
358
+ for i, single_image_embeds in enumerate(image_embeds):
359
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
360
+ single_image_embeds = single_image_embeds.to(device=device)
361
+ ip_adapter_image_embeds.append(single_image_embeds)
362
+
363
+ return ip_adapter_image_embeds
364
+
365
+ def check_inputs(
366
+ self,
367
+ prompt,
368
+ prompt_2,
369
+ height,
370
+ width,
371
+ negative_prompt=None,
372
+ negative_prompt_2=None,
373
+ prompt_embeds=None,
374
+ negative_prompt_embeds=None,
375
+ pooled_prompt_embeds=None,
376
+ negative_pooled_prompt_embeds=None,
377
+ callback_on_step_end_tensor_inputs=None,
378
+ max_sequence_length=None,
379
+ ):
380
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
381
+ logger.warning(
382
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
383
+ )
384
+
385
+ if callback_on_step_end_tensor_inputs is not None and not all(
386
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
387
+ ):
388
+ raise ValueError(
389
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
390
+ )
391
+
392
+ if prompt is not None and prompt_embeds is not None:
393
+ raise ValueError(
394
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
395
+ " only forward one of the two."
396
+ )
397
+ elif prompt_2 is not None and prompt_embeds is not None:
398
+ raise ValueError(
399
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
400
+ " only forward one of the two."
401
+ )
402
+ elif prompt is None and prompt_embeds is None:
403
+ raise ValueError(
404
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
405
+ )
406
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
407
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
408
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
409
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
410
+
411
+ if negative_prompt is not None and negative_prompt_embeds is not None:
412
+ raise ValueError(
413
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
414
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
415
+ )
416
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
417
+ raise ValueError(
418
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
419
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
420
+ )
421
+
422
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
423
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
424
+ raise ValueError(
425
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
426
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
427
+ f" {negative_prompt_embeds.shape}."
428
+ )
429
+
430
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
431
+ raise ValueError(
432
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
433
+ )
434
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
435
+ raise ValueError(
436
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
437
+ )
438
+
439
+ if max_sequence_length is not None and max_sequence_length > 512:
440
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
441
+
442
+ @staticmethod
443
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
444
+ latent_image_ids = torch.zeros(height, width, 3)
445
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
446
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
447
+
448
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
449
+
450
+ latent_image_ids = latent_image_ids.reshape(
451
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
452
+ )
453
+
454
+ return latent_image_ids.to(device=device, dtype=dtype)
455
+
456
+ @staticmethod
457
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
458
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
459
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
460
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
461
+
462
+ return latents
463
+
464
+ @staticmethod
465
+ def _unpack_latents(latents, height, width, vae_scale_factor):
466
+ batch_size, num_patches, channels = latents.shape
467
+
468
+ # VAE applies 8x compression on images but we must also account for packing which requires
469
+ # latent height and width to be divisible by 2.
470
+ height = 2 * (int(height) // (vae_scale_factor * 2))
471
+ width = 2 * (int(width) // (vae_scale_factor * 2))
472
+
473
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
474
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
475
+
476
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
477
+
478
+ return latents
479
+
480
    def enable_vae_slicing(self):
        """Decode the batch one image at a time to reduce peak VAE memory."""
        self.vae.enable_slicing()
482
+
483
    def disable_vae_slicing(self):
        """Restore whole-batch VAE decoding (undoes `enable_vae_slicing`)."""
        self.vae.disable_slicing()
485
+
486
    def enable_vae_tiling(self):
        """Decode images in tiles so large resolutions fit in VAE memory."""
        self.vae.enable_tiling()
488
+
489
    def disable_vae_tiling(self):
        """Restore single-pass VAE decoding (undoes `enable_vae_tiling`)."""
        self.vae.disable_tiling()
491
+
492
+ def prepare_latents(
493
+ self,
494
+ batch_size,
495
+ num_channels_latents,
496
+ height,
497
+ width,
498
+ dtype,
499
+ device,
500
+ generator,
501
+ latents=None,
502
+ ):
503
+ # VAE applies 8x compression on images but we must also account for packing which requires
504
+ # latent height and width to be divisible by 2.
505
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
506
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
507
+
508
+ shape = (batch_size, num_channels_latents, height, width)
509
+
510
+ if latents is not None:
511
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
512
+ return latents.to(device=device, dtype=dtype), latent_image_ids
513
+
514
+ if isinstance(generator, list) and len(generator) != batch_size:
515
+ raise ValueError(
516
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
517
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
518
+ )
519
+
520
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
521
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
522
+
523
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
524
+
525
+ return latents, latent_image_ids
526
+
527
    @property
    def guidance_scale(self):
        # Read-only view of the per-call guidance scale set in `__call__`.
        return self._guidance_scale
530
+
531
    @property
    def joint_attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._joint_attention_kwargs
534
+
535
    @property
    def num_timesteps(self):
        # Number of scheduler timesteps for the current/last call.
        return self._num_timesteps
538
+
539
    @property
    def current_timestep(self):
        # Timestep currently being denoised; None outside the denoising loop.
        return self._current_timestep
542
+
543
    @property
    def interrupt(self):
        # When True, the denoising loop skips remaining steps.
        return self._interrupt
546
+
547
+ @torch.no_grad()
548
+ def __call__(
549
+ self,
550
+ prompt: Union[str, List[str]] = None,
551
+ prompt_2: Optional[Union[str, List[str]]] = None,
552
+ negative_prompt: Union[str, List[str]] = None,
553
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
554
+ true_cfg_scale: float = 1.0,
555
+ height: Optional[int] = None,
556
+ width: Optional[int] = None,
557
+ num_inference_steps: int = 28,
558
+ sigmas: Optional[List[float]] = None,
559
+ guidance_scale: float = 3.5,
560
+ num_images_per_prompt: Optional[int] = 1,
561
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
562
+ latents: Optional[torch.FloatTensor] = None,
563
+ prompt_embeds: Optional[torch.FloatTensor] = None,
564
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
565
+ ip_adapter_image: Optional[PipelineImageInput] = None,
566
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
567
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
568
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
569
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ output_type: Optional[str] = "pil",
572
+ return_dict: bool = True,
573
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
574
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
575
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
576
+ max_sequence_length: int = 512,
577
+ ):
578
+ height = height or self.default_sample_size * self.vae_scale_factor
579
+ width = width or self.default_sample_size * self.vae_scale_factor
580
+
581
+ # 1. Check inputs. Raise error if not correct
582
+ self.check_inputs(
583
+ prompt,
584
+ prompt_2,
585
+ height,
586
+ width,
587
+ negative_prompt=negative_prompt,
588
+ negative_prompt_2=negative_prompt_2,
589
+ prompt_embeds=prompt_embeds,
590
+ negative_prompt_embeds=negative_prompt_embeds,
591
+ pooled_prompt_embeds=pooled_prompt_embeds,
592
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
593
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
594
+ max_sequence_length=max_sequence_length,
595
+ )
596
+
597
+ self._guidance_scale = guidance_scale
598
+ self._joint_attention_kwargs = joint_attention_kwargs
599
+ self._current_timestep = None
600
+ self._interrupt = False
601
+
602
+ # 2. Define call parameters
603
+ if prompt is not None and isinstance(prompt, str):
604
+ batch_size = 1
605
+ elif prompt is not None and isinstance(prompt, list):
606
+ batch_size = len(prompt)
607
+ else:
608
+ batch_size = prompt_embeds.shape[0]
609
+
610
+ device = self._execution_device
611
+
612
+ lora_scale = (
613
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
614
+ )
615
+ has_neg_prompt = negative_prompt is not None or (
616
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
617
+ )
618
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
619
+ (
620
+ prompt_embeds,
621
+ pooled_prompt_embeds,
622
+ text_ids,
623
+ ) = self.encode_prompt(
624
+ prompt=prompt,
625
+ prompt_2=prompt_2,
626
+ prompt_embeds=prompt_embeds,
627
+ pooled_prompt_embeds=pooled_prompt_embeds,
628
+ device=device,
629
+ num_images_per_prompt=num_images_per_prompt,
630
+ max_sequence_length=max_sequence_length,
631
+ lora_scale=lora_scale,
632
+ )
633
+ if do_true_cfg:
634
+ (
635
+ negative_prompt_embeds,
636
+ negative_pooled_prompt_embeds,
637
+ _,
638
+ ) = self.encode_prompt(
639
+ prompt=negative_prompt,
640
+ prompt_2=negative_prompt_2,
641
+ prompt_embeds=negative_prompt_embeds,
642
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
643
+ device=device,
644
+ num_images_per_prompt=num_images_per_prompt,
645
+ max_sequence_length=max_sequence_length,
646
+ lora_scale=lora_scale,
647
+ )
648
+
649
+ # 4. Prepare latent variables
650
+ num_channels_latents = 16 #self.transformer.config.in_channels // 4
651
+ latents, latent_image_ids = self.prepare_latents(
652
+ batch_size * num_images_per_prompt,
653
+ num_channels_latents,
654
+ height,
655
+ width,
656
+ prompt_embeds.dtype,
657
+ device,
658
+ generator,
659
+ latents,
660
+ )
661
+
662
+ # 5. Prepare timesteps
663
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
664
+ image_seq_len = latents.shape[1]
665
+ mu = calculate_shift(
666
+ image_seq_len,
667
+ self.scheduler.config.get("base_image_seq_len", 256),
668
+ self.scheduler.config.get("max_image_seq_len", 4096),
669
+ self.scheduler.config.get("base_shift", 0.5),
670
+ self.scheduler.config.get("max_shift", 1.16),
671
+ )
672
+ timesteps, num_inference_steps = retrieve_timesteps(
673
+ self.scheduler,
674
+ num_inference_steps,
675
+ device,
676
+ sigmas=sigmas,
677
+ mu=mu,
678
+ )
679
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
680
+ self._num_timesteps = len(timesteps)
681
+
682
+ # handle guidance
683
+ if False: #self.transformer.config.guidance_embeds:
684
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
685
+ guidance = guidance.expand(latents.shape[0])
686
+ else:
687
+ guidance = None
688
+
689
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
690
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
691
+ ):
692
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
693
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
694
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
695
+ ):
696
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
697
+
698
+ if self.joint_attention_kwargs is None:
699
+ self._joint_attention_kwargs = {}
700
+
701
+ image_embeds = None
702
+ negative_image_embeds = None
703
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
704
+ image_embeds = self.prepare_ip_adapter_image_embeds(
705
+ ip_adapter_image,
706
+ ip_adapter_image_embeds,
707
+ device,
708
+ batch_size * num_images_per_prompt,
709
+ )
710
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
711
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
712
+ negative_ip_adapter_image,
713
+ negative_ip_adapter_image_embeds,
714
+ device,
715
+ batch_size * num_images_per_prompt,
716
+ )
717
+
718
+ # 6. Denoising loop
719
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
720
+ for i, t in enumerate(timesteps):
721
+ if self.interrupt:
722
+ continue
723
+
724
+ self._current_timestep = t
725
+ if image_embeds is not None:
726
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
727
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
728
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
729
+
730
+ noise_pred = self.transformer(
731
+ hidden_states=latents,
732
+ timestep=timestep / 1000,
733
+ guidance=guidance,
734
+ pooled_projections=pooled_prompt_embeds,
735
+ encoder_hidden_states=prompt_embeds,
736
+ txt_ids=text_ids,
737
+ img_ids=latent_image_ids,
738
+ joint_attention_kwargs=self.joint_attention_kwargs,
739
+ return_dict=False,
740
+ )[0]
741
+
742
+ if do_true_cfg:
743
+ if negative_image_embeds is not None:
744
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
745
+ neg_noise_pred = self.transformer(
746
+ hidden_states=latents,
747
+ timestep=timestep / 1000,
748
+ guidance=guidance,
749
+ pooled_projections=negative_pooled_prompt_embeds,
750
+ encoder_hidden_states=negative_prompt_embeds,
751
+ txt_ids=text_ids,
752
+ img_ids=latent_image_ids,
753
+ joint_attention_kwargs=self.joint_attention_kwargs,
754
+ return_dict=False,
755
+ )[0]
756
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
757
+
758
+ # compute the previous noisy sample x_t -> x_t-1
759
+ latents_dtype = latents.dtype
760
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
761
+
762
+ if latents.dtype != latents_dtype:
763
+ if torch.backends.mps.is_available():
764
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
765
+ latents = latents.to(latents_dtype)
766
+
767
+ if callback_on_step_end is not None:
768
+ callback_kwargs = {}
769
+ for k in callback_on_step_end_tensor_inputs:
770
+ callback_kwargs[k] = locals()[k]
771
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
772
+
773
+ latents = callback_outputs.pop("latents", latents)
774
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
775
+
776
+ # call the callback, if provided
777
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
778
+ progress_bar.update()
779
+
780
+
781
+ self._current_timestep = None
782
+
783
+ if output_type == "latent":
784
+ image = latents
785
+ else:
786
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
787
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
788
+ image = self.vae.decode(latents, return_dict=False)[0]
789
+ image = self.image_processor.postprocess(image, output_type=output_type)
790
+
791
+ # Offload all models
792
+ self.maybe_free_model_hooks()
793
+
794
+ if not return_dict:
795
+ return (image,)
796
+
797
+ return FluxPipelineOutput(images=image)
798
+
799
def get_example_inputs():
    """Load serialized example transformer inputs from the local HF cache.

    Returns a dict of CUDA tensors plus the non-tensor kwargs the Flux
    transformer forward expects (``joint_attention_kwargs``/``guidance``
    are ``None``, ``return_dict`` is ``False``).

    NOTE(review): this zero-argument definition is shadowed by a later
    ``get_example_inputs(batch_size, height, width, ...)`` definition in
    this file — confirm which one callers intend to use.
    """
    cached = torch.load(
        "/root/.cache/huggingface/hub/models--sayakpaul--flux.1-dev-int8-aot-compiled/snapshots/3b4f77e9752dd278c432870d101b958c902af2c9/serialized_inputs.pt",
        weights_only=True,
    )
    inputs = {name: tensor.to("cuda") for name, tensor in cached.items()}
    inputs["joint_attention_kwargs"] = None
    inputs["return_dict"] = False
    inputs["guidance"] = None
    return inputs
@torch.no_grad()
def f(model, **kwargs):
    """Invoke *model* with the given keyword arguments under inference
    (no-grad) mode and return its result."""
    result = model(**kwargs)
    return result
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the mean runtime in seconds
    as a string formatted to three decimal places."""
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    mean_seconds = timer.blocked_autorange().mean
    return f"{mean_seconds:.3f}"
def prepare_latents(batch_size, height, width, num_channels_latents=1):
    """Create random bfloat16 latents on CUDA and pack them into the
    sequence layout the Flux transformer consumes.

    ``height``/``width`` are pixel dimensions; each is reduced by the VAE
    scale factor (16) then doubled, matching the diffusers Flux latent
    sizing convention.
    """
    vae_scale_factor = 16
    latent_height = 2 * (int(height) // vae_scale_factor)
    latent_width = 2 * (int(width) // vae_scale_factor)
    pre_hidden_states = torch.randn(
        (batch_size, num_channels_latents, latent_height, latent_width),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # _pack_latents flattens 2x2 latent patches into a token sequence.
    return FluxPipeline._pack_latents(
        pre_hidden_states, batch_size, num_channels_latents, latent_height, latent_width
    )
def get_example_inputs(batch_size=1, height=1024, width=1024, num_channels_latents=1):
    """Build random example inputs for the Flux transformer forward pass.

    Bug fix: this definition shadows an earlier zero-argument
    ``get_example_inputs`` in this file, yet ``load_pipeline`` calls
    ``get_example_inputs()`` with no arguments — previously a TypeError.
    Defaults of batch 1 at 1024x1024 match the precompiled
    ``bs_1_1024.pt2`` AoT package that caller loads, so the no-arg call
    now works and all existing positional calls are unaffected.

    Args:
        batch_size: number of samples in the batch (default 1).
        height, width: pixel dimensions used to size the packed latents
            (default 1024x1024).
        num_channels_latents: latent channel count passed through to
            ``prepare_latents``.

    Returns:
        dict of kwargs for the transformer: CUDA bfloat16 tensors plus
        ``guidance=None``, ``joint_attention_kwargs=None`` and
        ``return_dict=False``.
    """
    hidden_states = prepare_latents(batch_size, height, width, num_channels_latents)
    # Packed latents are (batch, seq_len, channels); seq_len sizes img_ids.
    num_img_sequences = hidden_states.shape[1]
    example_inputs = {
        "hidden_states": hidden_states,
        "encoder_hidden_states": torch.randn(batch_size, 512, 4096, dtype=torch.bfloat16, device="cuda"),
        "pooled_projections": torch.randn(batch_size, 768, dtype=torch.bfloat16, device="cuda"),
        "timestep": torch.tensor([1.0], device="cuda").expand(batch_size),
        "img_ids": torch.randn(num_img_sequences, 3, dtype=torch.bfloat16, device="cuda"),
        "txt_ids": torch.randn(512, 3, dtype=torch.bfloat16, device="cuda"),
        # The original built a guidance tensor and immediately overwrote it
        # with None; the net value is None, set once here.
        "guidance": None,
        "joint_attention_kwargs": None,
        "return_dict": False,
    }
    return example_inputs
  def load_pipeline() -> Pipeline:
849
  model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
850
  revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851
 
852
  hub_model_dir = os.path.join(
853
  HF_HUB_CACHE,
 
872
  )
873
  # pipeline.vae = torch.compile(vae)
874
  pipeline.to("cuda")
875
+
876
+ path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell-compiled_transformer/snapshots/a59b2d689b775f3a4177c5ade0a63e5b6148aa03/bs_1_1024.pt2")
877
+ inputs1 = get_example_inputs()
878
+ print(f"AoT pre compiled path is {path}")
879
+
880
+ transformer = torch._inductor.aoti_load_package(path)
881
+
882
+ for _ in range(2):
883
+ _ = transformer(**inputs1)[0]
884
+
885
+ pipeline.transformer = transformer
886
 
887
  warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
888
  for _ in range(1):