hnpinq committed on
Commit
0c25100
·
verified ·
1 Parent(s): 5b71dc6

Upload folder using huggingface_hub

Browse files
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (158 Bytes). View file
 
src/__pycache__/pipeline_flux_tryon.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
src/pipeline_flux_tryon.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import inspect
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ from PIL import Image
8
+ import torch
9
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
10
+
11
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
12
+ from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
13
+ from diffusers.models.autoencoders import AutoencoderKL
14
+ from diffusers.models.transformers import FluxTransformer2DModel
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ is_torch_xla_available,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
27
+
28
+ from diffusers.pipelines import FluxInpaintPipeline
29
+ from diffusers.pipelines.flux.pipeline_flux_inpaint import calculate_shift, retrieve_latents, retrieve_timesteps
30
+
31
+
32
class FluxTryonPipeline(FluxInpaintPipeline):
    """Flux inpainting pipeline that works on a horizontally concatenated canvas.

    The input image is split along its width at ``target_width``. The part to
    the left of the split and the part to the right of it are VAE-encoded
    separately and concatenated again in latent space, the positional ids mark
    which side every latent token belongs to, and only the left
    ``target_width`` columns are decoded and returned.

    NOTE(review): the right-hand part presumably holds the reference /
    condition image(s) for virtual try-on; this file does not confirm it.
    """

    # Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype, target_width=-1, tryon=False):
        """Build the ``(height * width, 3)`` positional-id tensor for packed latents.

        Channel layout of the result:

        * channel 0 -- region index: 0 everywhere when ``target_width == -1``;
          otherwise 1 for columns ``>= target_width`` and, when ``tryon`` is
          set, 2 for columns ``>= 2 * target_width``.
        * channel 1 -- row index.
        * channel 2 -- column index. With ``tryon`` the numbering restarts at
          0 at column ``target_width`` (it does not restart again at
          ``2 * target_width``); otherwise it runs over the full width.

        Args:
            batch_size: Unused; kept for signature compatibility with the
                upstream helper.
            height: Number of token rows.
            width: Number of token columns.
            device: Device the result is moved to.
            dtype: Dtype the result is cast to.
            target_width: Column at which the canvas is split, in token
                units. ``-1`` disables the split.
            tryon: Whether to use the region-restarting column numbering.

        Returns:
            Tensor of shape ``(height * width, 3)`` on ``device`` with ``dtype``.
        """
        ids = torch.zeros(height, width, 3)

        # Row indices are identical in every mode.
        ids[..., 1] = ids[..., 1] + torch.arange(height)[:, None]

        if target_width != -1:
            ids[:, target_width:, 0] = 1

        if target_width != -1 and tryon:
            ids[:, target_width * 2:, 0] = 2
            # Left region: columns are numbered from 0.
            ids[:, :target_width, 2] = ids[:, :target_width, 2] + torch.arange(target_width)[None, :]
            # Everything right of the split: numbering restarts at 0.
            ids[:, target_width:, 2] = ids[:, target_width:, 2] + torch.arange(width - target_width)[None, :]
        else:
            ids[..., 2] = ids[..., 2] + torch.arange(width)[None, :]

        rows, cols, channels = ids.shape
        ids = ids.reshape(rows * cols, channels)
        return ids.to(device=device, dtype=dtype)

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_channels_latents,
        height,
        width,
        target_width,
        tryon,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Encode the split image and build the initial packed latents.

        Args:
            image: Preprocessed image tensor; it is sliced along its last
                dimension at ``target_width``.
            timestep: Timestep handed to ``scheduler.scale_noise``.
            batch_size: Effective batch size (prompts x images per prompt).
            num_channels_latents: Channel count of the unpacked latents.
            height: Image height in pixels.
            width: Image width in pixels.
            target_width: Pixel column at which ``image`` is split. ``None``
                means "no split" (the full image width is used).
            tryon: Forwarded to ``_prepare_latent_image_ids``.
            dtype: Dtype for the image, noise and positional ids.
            device: Device for the image, noise and positional ids.
            generator: ``torch.Generator`` or a list with one generator per
                batch element.
            latents: Optional pre-made latents. When given they are used
                unchanged as both the noise and the starting latents.

        Returns:
            Tuple ``(latents, noise, image_latents, latent_image_ids)``; the
            first three have been passed through ``_pack_latents``.

        Raises:
            ValueError: If a generator list does not match ``batch_size`` or
                the encoded image batch cannot be tiled to ``batch_size``.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if target_width is None:
            # No split requested: treat the whole image as the target region.
            target_width = image.shape[-1]

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        packing_stride = self.vae_scale_factor * 2
        height = 2 * (int(height) // packing_stride)
        width = 2 * (int(width) // packing_stride)
        shape = (batch_size, num_channels_latents, height, width)

        # Split column expressed in packed-token units.
        split_col = int(target_width) // packing_stride
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size, height // 2, width // 2, device, dtype, split_col, tryon
        )

        image = image.to(device=device, dtype=dtype)
        # Encode both sides of the split independently, then stitch them back
        # together along the width. An empty side (split at the image border)
        # is skipped because the VAE cannot encode a zero-width tensor.
        pixel_parts = (image[:, :, :, :target_width], image[:, :, :, target_width:])
        encoded_parts = [
            self._encode_vae_image(image=part, generator=generator)
            for part in pixel_parts
            if part.shape[-1] > 0
        ]
        image_latents = torch.cat(encoded_parts, dim=-1)

        encoded_batch = image_latents.shape[0]
        if batch_size > encoded_batch:
            if batch_size % encoded_batch != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {encoded_batch} to {batch_size} text prompts."
                )
            # Tile the encoded image so there is one copy per requested sample.
            image_latents = torch.cat([image_latents] * (batch_size // encoded_batch), dim=0)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
        else:
            noise = latents.to(device)
            latents = noise

        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        return latents, noise, image_latents, latent_image_ids

    def prepare_mask_latents(
        self,
        mask,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        """Resize, tile and pack the mask and the masked-image latents.

        Args:
            mask: Mask tensor; it is resized to the latent resolution with
                nearest-neighbour interpolation.
            masked_image: Either pixel data (encoded with the VAE here) or,
                when its channel dimension is 16, latents that are used
                without encoding.
            batch_size: Number of prompts; multiplied by
                ``num_images_per_prompt`` internally.
            num_channels_latents: Channel count of the unpacked latents.
            num_images_per_prompt: Samples generated per prompt.
            height: Image height in pixels.
            width: Image width in pixels.
            dtype: Target dtype.
            device: Target device.
            generator: Generator forwarded to ``retrieve_latents``.

        Returns:
            Tuple ``(mask, masked_image_latents)``, both passed through
            ``_pack_latents``.

        Raises:
            ValueError: If the mask or image batch cannot be tiled to the
                effective batch size.
        """
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        packing_stride = self.vae_scale_factor * 2
        height = 2 * (int(height) // packing_stride)
        width = 2 * (int(width) // packing_stride)

        # Resize the mask to the latent shape before casting, to avoid breaking
        # when cpu_offload and half precision are used together.
        mask = torch.nn.functional.interpolate(mask, size=(height, width), mode="nearest")
        mask = mask.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        masked_image = masked_image.to(device=device, dtype=dtype)
        if masked_image.shape[1] == 16:
            # Already latents: skip the VAE.
            masked_image_latents = masked_image
        else:
            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # Duplicate mask and masked_image_latents for each generation per prompt (mps friendly).
        mask_batch = mask.shape[0]
        if mask_batch < batch_size:
            if batch_size % mask_batch != 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask_batch, 1, 1, 1)

        latent_batch = masked_image_latents.shape[0]
        if latent_batch < batch_size:
            if batch_size % latent_batch != 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // latent_batch, 1, 1, 1)

        # Align device/dtype to prevent errors when combined with the latent model input.
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        masked_image_latents = self._pack_latents(
            masked_image_latents,
            batch_size,
            num_channels_latents,
            height,
            width,
        )
        # The single-channel mask is repeated across channels so it packs like the latents.
        mask = self._pack_latents(
            mask.repeat(1, num_channels_latents, 1, 1),
            batch_size,
            num_channels_latents,
            height,
            width,
        )

        return mask, masked_image_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_width: Optional[int] = None,
        tryon: bool = False,
        padding_mask_crop: Optional[int] = None,
        strength: float = 0.6,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Run masked denoising on the concatenated canvas and return its left part.

        Only the parameters specific to this subclass are described in
        detail; the remaining ones are forwarded to the inherited helpers
        (``check_inputs``, ``encode_prompt``, ``get_timesteps``, ...).

        Args:
            image: Full canvas; it is resized to ``height`` x ``width``.
            mask_image: Mask over the full canvas. Where the packed mask is 1
                the denoised latents are kept; where it is 0 the clean
                image latents are re-imposed before every transformer call.
            height: Canvas height in pixels; defaults to
                ``default_sample_size * vae_scale_factor``.
            width: Canvas width in pixels; same default rule as ``height``.
            target_width: Pixel column at which the canvas is split. Only the
                columns left of it are decoded and returned. ``None`` (the
                default) uses the full ``width``, i.e. no split.
            tryon: Selects the region-restarting positional ids, see
                ``_prepare_latent_image_ids``.
            output_type: ``"latent"`` returns the packed latents of the whole
                canvas without cropping or decoding.
            return_dict: When ``False`` a one-element tuple is returned.

        Returns:
            ``FluxPipelineOutput`` holding the generated images, or
            ``(images,)`` when ``return_dict`` is ``False``.

        Raises:
            ValueError: If ``target_width`` is 0 or larger than ``width``, or
                if ``strength`` leaves fewer than one inference step.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        if target_width is None:
            # Without an explicit split the whole canvas is the target region.
            target_width = width
        elif target_width == 0 or target_width > width:
            # These values leave one side of the split empty in a way that can
            # never work; fail early with a readable message.
            raise ValueError(
                f"`target_width` must be non-zero and at most `width` ({width}), but got {target_width}."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            image,
            mask_image,
            strength,
            height,
            width,
            output_type=output_type,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            padding_mask_crop=padding_mask_crop,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Preprocess mask and image
        if padding_mask_crop is not None:
            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        init_image = self.image_processor.preprocess(
            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
        )
        init_image = init_image.to(dtype=torch.float32)

        # 3. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4

        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            target_width,
            tryon,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
        )

        if masked_image_latents is None:
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size,
            num_channels_latents,
            num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # The unmasked part of the canvas never changes, so compute it once.
        # NOTE: the clean (un-noised) image latents are used on purpose.
        clean_context = (1 - mask) * image_latents

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Re-impose the clean context *before* the transformer call so
                # the model always sees clean latents outside the mask. The
                # blend is not repeated after the scheduler step.
                latents = clean_context + mask * latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # Looked up by name, so `latents` / `prompt_embeds`
                        # must stay local variables of this method.
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # advance the progress bar once per completed scheduler step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            # Keep only the target (left) region of the canvas.
            latents = latents[:, :, :, :target_width // self.vae_scale_factor]
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(device=self.vae.device, dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)