lime-j committed on
Commit
f8e1897
·
1 Parent(s): ae8c337

init with GenSIRR

Browse files
__pycache__/optimization.cpython-310.pyc ADDED
Binary file (1.72 kB). View file
 
__pycache__/optimization_utils.cpython-310.pyc ADDED
Binary file (4.75 kB). View file
 
__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (18 kB). View file
 
app.py CHANGED
@@ -10,15 +10,32 @@ import torch
10
  import random
11
  from PIL import Image
12
 
13
- from diffusers import FluxKontextPipeline
14
  from diffusers.utils import load_image
15
 
16
  from optimization import optimize_pipeline_
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
 
19
 
20
- pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
21
- optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @spaces.GPU
24
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
@@ -64,24 +81,15 @@ def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5
64
  if randomize_seed:
65
  seed = random.randint(0, MAX_SEED)
66
 
67
- if input_image:
68
- input_image = input_image.convert("RGB")
69
- image = pipe(
70
- image=input_image,
71
- prompt=prompt,
72
- guidance_scale=guidance_scale,
73
- width = input_image.size[0],
74
- height = input_image.size[1],
75
- num_inference_steps=steps,
76
- generator=torch.Generator().manual_seed(seed),
77
- ).images[0]
78
- else:
79
- image = pipe(
80
- prompt=prompt,
81
- guidance_scale=guidance_scale,
82
- num_inference_steps=steps,
83
- generator=torch.Generator().manual_seed(seed),
84
- ).images[0]
85
  return image, seed, gr.Button(visible=True)
86
 
87
  @spaces.GPU
@@ -147,17 +155,6 @@ Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro
147
  reuse_button = gr.Button("Reuse this image", visible=False)
148
 
149
 
150
- examples = gr.Examples(
151
- examples=[
152
- ["flowers.png", "turn the flowers into sunflowers"],
153
- ["monster.png", "make this monster ride a skateboard on the beach"],
154
- ["cat.png", "make this cat happy"]
155
- ],
156
- inputs=[input_image, prompt],
157
- outputs=[result, seed],
158
- fn=infer_example,
159
- cache_examples="lazy"
160
- )
161
 
162
  gr.on(
163
  triggers=[run_button.click, prompt.submit],
 
10
  import random
11
  from PIL import Image
12
 
13
+ from pipeline import GenSIRR
14
  from diffusers.utils import load_image
15
 
16
  from optimization import optimize_pipeline_
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
+ from huggingface_hub import hf_hub_download
20
 
21
+ def load_deepspeed_weights(model, checkpoint_path) -> None:
22
+ """Load LoRA weights from a DeepSpeed ZeRO Stage 2 checkpoint into the model."""
23
+ tensor_path = checkpoint_path
24
+ # LOGGER.info("Loading ZeRO checkpoint from %s", tensor_path)
25
+ raw_state = torch.load(tensor_path, map_location="cpu")
26
+ module_state: Dict[str, torch.Tensor] = raw_state.get("module")
27
+ if module_state is None:
28
+ raise KeyError("Checkpoint is missing the 'module' state dict")
29
+
30
+ # Remove the Lightning prefix so it matches the FluxKontext state dict.
31
+ cleaned_state = {key[len("net_g."):]: value for key, value in module_state.items() if key.startswith("net_g.")}
32
+
33
+ missing, unexpected = model.load_state_dict(cleaned_state, strict=False)
34
+
35
+ pipe = GenSIRR("black-forest-labs/FLUX.1-Kontext-dev")
36
+ load_deepspeed_weights(pipe, hf_hub_download(repo_id='lime-j/GenSIRR', filename="GenSIRR.pt"))
37
+ pipe = pipe.to("cuda")
38
+ # optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
39
 
40
  @spaces.GPU
41
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
 
81
  if randomize_seed:
82
  seed = random.randint(0, MAX_SEED)
83
 
84
+
85
+ input_image = input_image.convert("RGB")
86
+ image = pipe(
87
+ image=input_image,
88
+ width = input_image.size[0],
89
+ height = input_image.size[1],
90
+ num_inference_steps=steps,
91
+ generator=torch.Generator().manual_seed(seed),
92
+ ).images[0]
 
 
 
 
 
 
 
 
 
93
  return image, seed, gr.Button(visible=True)
94
 
95
  @spaces.GPU
 
155
  reuse_button = gr.Button("Reuse this image", visible=False)
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  gr.on(
160
  triggers=[run_button.click, prompt.submit],
pipeline.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
4
+ from typing import Dict, Any, Optional, List, Callable, Union
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ from diffusers import FluxKontextPipeline
9
+ from diffusers.image_processor import PipelineImageInput
10
+ from diffusers.utils import (
11
+ USE_PEFT_BACKEND,
12
+ is_torch_xla_available,
13
+ logging,
14
+ scale_lora_layers,
15
+ unscale_lora_layers,
16
+ )
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from peft import LoraConfig, LoraModel, get_peft_model
19
+
20
+ torch.set_float32_matmul_precision('medium')
21
+ if is_torch_xla_available():
22
+ import torch_xla.core.xla_model as xm
23
+ XLA_AVAILABLE = True
24
+ else:
25
+ XLA_AVAILABLE = False
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ PREFERRED_KONTEXT_RESOLUTIONS = [
30
+ (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
31
+ (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
32
+ (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
33
+ (1504, 688), (1568, 672),
34
+ ]
35
+
36
+
37
+ def _resolve_vae_path(user_path: Optional[str] = None) -> str:
38
+ """Resolve where to load the VAE weights from."""
39
+ repo_root = Path(__file__).resolve().parents[2]
40
+ candidates = [
41
+ user_path,
42
+ os.environ.get("FLUX_VAE_PATH"),
43
+ os.environ.get("VAE_PATH"),
44
+ repo_root / "vae_merged",
45
+ "/home/s1023244038/XReflection/vae_merged",
46
+ ]
47
+
48
+ for candidate in candidates:
49
+ if candidate is None:
50
+ continue
51
+ candidate_path = Path(candidate).expanduser()
52
+ if candidate_path.exists():
53
+ return str(candidate_path)
54
+
55
+ raise FileNotFoundError(
56
+ "Could not locate VAE weights. Please set the `FLUX_VAE_PATH` "
57
+ "environment variable to the directory that contains the merged VAE."
58
+ )
59
+
60
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    Supports outputs exposing a ``latent_dist`` (sampled or mode-taken
    depending on ``sample_mode``) or a plain ``latents`` attribute.

    Raises:
        AttributeError: If neither attribute is present.
    """
    has_dist = hasattr(encoder_output, "latent_dist")
    if has_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
71
+
72
def calculate_shift(image_seq_len, base_image_seq_len, max_image_seq_len, base_shift, max_shift):
    """Linearly interpolate the scheduler shift for a given latent sequence length.

    Maps ``base_image_seq_len`` -> ``base_shift`` and ``max_image_seq_len`` ->
    ``max_shift``; intermediate lengths fall on the connecting line.
    """
    slope = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    return base_shift + slope * (image_seq_len - base_image_seq_len)
76
+
77
def retrieve_timesteps(
    scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs,
):
    """Configure the scheduler's schedule and return ``(timesteps, num_inference_steps)``.

    Exactly one of ``timesteps`` / ``sigmas`` may be given; ``timesteps`` takes
    priority, then ``num_inference_steps``, then ``sigmas``.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or none of
            the three scheduling inputs is provided.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is not None:
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device, timesteps=timesteps, **kwargs)
        resolved = scheduler.timesteps
        return resolved, len(resolved)

    if num_inference_steps is not None:
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        resolved = scheduler.timesteps
        return resolved, len(resolved)

    raise ValueError("Either `num_inference_steps` or `timesteps` or `sigmas` has to be passed.")
98
+
99
+
100
+
101
class GenSIRR(nn.Module):
    """Single-image reflection removal pipeline built on FLUX.1-Kontext.

    Wraps the components of a `FluxKontextPipeline` in an `nn.Module` so the
    model can be trained (gradient checkpointing is enabled on the transformer)
    and invoked like a regular module.  At inference time `forward()` overrides
    the user prompt and conditions on text embeddings cached in
    `prompt_embeds.pth` / `pooled_prompt_embeds.pth` / `text_ids.pth`.
    """

    # Tensor names a step callback may request via
    # `callback_on_step_end_tensor_inputs`.
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, model_path, train_dit: bool = True, vae_path: Optional[str] = None):
        """Build the pipeline from a pretrained FLUX.1-Kontext checkpoint.

        Args:
            model_path: HF repo id or local path forwarded to
                `FluxKontextPipeline.from_pretrained`.
            train_dit: Stored on the instance; not read elsewhere in this class.
            vae_path: NOTE(review): accepted but currently ignored — presumably
                meant to feed `_resolve_vae_path`; confirm before relying on it.
        """
        super().__init__()
        self.train_dit = train_dit
        pipe = FluxKontextPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
        self.dtype = torch.bfloat16
        self.vae = pipe.vae

        self.text_encoder = pipe.text_encoder
        self.tokenizer = pipe.tokenizer
        # The T5 encoder/tokenizer are dropped; prompts are served from the
        # cached embeddings loaded below, so they are never re-encoded here.
        self.text_encoder_2 = None #pipe.text_encoder_2
        self.tokenizer_2 = None #pipe.tokenizer_2
        self.transformer = pipe.transformer
        self.scheduler = pipe.scheduler
        # self.image_encoder = pipe.image_encoder.to("cuda")
        self.image_processor = pipe.image_processor

        # Packed latents carry a 2x2 patch per token, hence the //4.
        self.latent_channels = self.transformer.config.in_channels // 4
        self.vae_scale_factor = pipe.vae_scale_factor
        self.joint_attention_kwargs = getattr(pipe, '_joint_attention_kwargs', None)
        self._execution_device = pipe._execution_device
        self.default_sample_size = pipe.default_sample_size
        self.interrupt = False
        self.tokenizer_max_length = pipe.tokenizer_max_length
        self.transformer.enable_gradient_checkpointing()
        # Pre-computed embeddings of the fixed editing instruction, registered
        # as Parameters so `.to(...)` moves them together with the module.
        # NOTE(review): nn.Parameter defaults to requires_grad=True, so these
        # caches are trainable unless excluded from the optimizer — confirm.
        self.cached_prompt_embeds = torch.nn.Parameter(torch.load("prompt_embeds.pth", map_location='cpu'))
        self.cached_pooled_prompt_embeds = torch.nn.Parameter(torch.load("pooled_prompt_embeds.pth", map_location='cpu'))
        self.cached_text_ids = torch.nn.Parameter(torch.load("text_ids.pth", map_location='cpu'))

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        """Return flattened (height*width, 3) positional ids: [flag, row, col]."""
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Fold each 2x2 latent patch into the channel dim: (B, C, H, W) -> (B, H/2*W/2, C*4)."""
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Inverse of `_pack_latents`: (B, tokens, C*4) -> (B, C, H, W) in latent space."""
        batch_size, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents

    def progress_bar(self, iterable):
        """Pass-through stand-in for the diffusers progress bar (no tqdm)."""
        return iterable

    def maybe_free_model_hooks(self):
        """No-op stand-in for the diffusers offload-hook cleanup."""
        pass


    def check_inputs(
        self, prompt, prompt_2, height, width, negative_prompt=None, negative_prompt_2=None,
        prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None,
    ):
        """Validate prompt/embedding/size argument combinations.

        Raises:
            ValueError: On non-multiple-of-8 sizes, mutually exclusive
                prompt/embedding arguments, wrong prompt types, or unpaired
                (pooled) embeddings.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if callback_on_step_end_tensor_inputs is not None and not all(k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            raise ValueError(f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}")
        if prompt is not None and prompt_embeds is not None:
            raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
        if prompt_2 is not None and prompt_embeds is not None:
            raise ValueError("Cannot forward both `prompt_2` and `prompt_embeds`.")
        if prompt is None and prompt_embeds is None:
            raise ValueError("Provide either `prompt` or `prompt_embeds`.")
        if prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
        if negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError("Cannot forward both `negative_prompt_2` and `negative_prompt_embeds`.")
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError("If `prompt_embeds` are provided, `pooled_prompt_embeds` must also be passed.")
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError("If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` must also be passed.")

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode `prompt` with the T5 encoder into sequence embeddings.

        NOTE(review): `__init__` sets `tokenizer_2`/`text_encoder_2` to None,
        so this method fails unless those components are restored; inference
        via `forward()` bypasses it by using the cached embeddings.
        """
        device = self.text_encoder_2.device if self.text_encoder_2 is not None else self.text_encoder.device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)


        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Encode `prompt` with CLIP and return the pooled embedding per prompt."""
        device = self.text_encoder.device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # if isinstance(self, TextualInversionLoaderMixin):
        #     prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode prompts into (T5 sequence embeds, CLIP pooled embeds, text ids).

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = self.text_encoder.device
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None:
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )


        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        # Text tokens share a single zeroed position id triple per token.
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)



        return prompt_embeds, pooled_prompt_embeds, text_ids


    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        """VAE-encode `image` (argmax mode) and normalize into model latent space."""
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")

        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        return image_latents

    def prepare_latents(self, image, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Prepare packed noise latents plus (optional) conditioning image latents.

        Returns:
            Tuple of (latents, image_latents, latent_ids, image_ids); the last
            two are positional-id tensors, with `image_ids` None when no
            conditioning image is given.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        shape = (batch_size, num_channels_latents, height, width)

        image_latents = image_ids = None
        if image is not None:
            image = image.to(device=device, dtype=dtype)
            # Already-encoded latents (channel count matches) skip the VAE.
            if image.shape[1] != self.latent_channels:
                image_latents = self._encode_vae_image(image=image, generator=generator)
            else:
                image_latents = image
            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
                # expand init_latents for batch_size
                additional_image_per_prompt = batch_size // image_latents.shape[0]
                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
                )
            else:
                image_latents = torch.cat([image_latents], dim=0)

            image_latent_height, image_latent_width = image_latents.shape[2:]
            image_latents = self._pack_latents(
                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
            )
            image_ids = self._prepare_latent_image_ids(
                batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
            )
            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
            image_ids[..., 0] = 1

        latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        else:
            latents = latents.to(device=device, dtype=dtype)

        return latents, image_latents, latent_ids, image_ids

    def forward(
        self, image: PipelineImageInput = None, prompt: Optional[str] = None, prompt_2: Optional[str] = None,
        negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None,
        height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28,
        guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil",
        return_dict: bool = True, max_area: int = 1024**2, _auto_resize: bool = True,
        **kwargs
    ):
        """Run reflection removal on `image` and return the decoded result.

        The user-supplied `prompt`/`prompt_2` are overwritten with the fixed
        instruction and the cached embeddings from `__init__` are used; the
        size is taken from `image.shape[2:4]`, so `image` must be a 4-D
        (B, C, H, W) tensor despite the `PipelineImageInput` annotation.

        Returns:
            `(image,)` if `return_dict` is False, the raw packed latents if
            `output_type == "latent"`, otherwise the decoded image rescaled
            from [-1, 1] to [0, 1] via `(image + 1) / 2`.
        """
        joint_attention_kwargs = kwargs.get("joint_attention_kwargs")
        prompt_embeds = kwargs.get("prompt_embeds")
        pooled_prompt_embeds = kwargs.get("pooled_prompt_embeds")
        negative_prompt_embeds = kwargs.get("negative_prompt_embeds")
        negative_pooled_prompt_embeds = kwargs.get("negative_pooled_prompt_embeds")
        ip_adapter_image = kwargs.get("ip_adapter_image")
        ip_adapter_image_embeds = kwargs.get("ip_adapter_image_embeds")
        negative_ip_adapter_image = kwargs.get("negative_ip_adapter_image")
        negative_ip_adapter_image_embeds = kwargs.get("negative_ip_adapter_image_embeds")
        callback_on_step_end = kwargs.get("callback_on_step_end")
        callback_on_step_end_tensor_inputs = kwargs.get("callback_on_step_end_tensor_inputs", ["latents"])
        max_sequence_length = kwargs.get("max_sequence_length", 512)

        sigmas = kwargs.get("sigmas")
        # Input size always wins over the height/width arguments.
        height, width = image.shape[2], image.shape[3]
        # height = height or self.default_sample_size * self.vae_scale_factor
        # width = width or self.default_sample_size * self.vae_scale_factor
        original_height, original_width = height, width
        aspect_ratio = width / height
        # width = round((max_area * aspect_ratio) ** 0.5)
        # height = round((max_area / aspect_ratio) ** 0.5)
        # Snap down to a multiple of 2 * vae_scale_factor for latent packing.
        multiple_of = self.vae_scale_factor * 2
        width = width // multiple_of * multiple_of
        height = height // multiple_of * multiple_of
        if height != original_height or width != original_width:
            logger.warning(f"Resizing to {height}x{width} to fit model requirements.")

        # Fixed task instruction: any caller-provided prompt is ignored.
        prompt = 'please remove the reflection in this image'
        prompt_2 = 'please remove the reflection in this image'

        if self.text_encoder_2 is not None:
            self.text_encoder.max_position_embeddings = 77
            self.text_encoder_2.max_position_embeddings = 512
        self.check_inputs(
            prompt, prompt_2, height, width, negative_prompt, negative_prompt_2,
            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
            negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs
        )
        self._joint_attention_kwargs = joint_attention_kwargs
        # NOTE(review): this sets `_interrupt`, but the denoising loop below
        # checks `self.interrupt` (set in __init__) — confirm which is intended.
        self._interrupt = False

        if prompt is not None and isinstance(prompt, str): batch_size = 1
        elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt)
        else: batch_size = prompt_embeds.shape[0]
        device = self.text_encoder.device
        lora_scale = self.joint_attention_kwargs.get("scale") if self.joint_attention_kwargs is not None else None

        do_classifier_free_guidance = guidance_scale > 1.0
        # Conditioning comes from the cached instruction embeddings, not from
        # encode_prompt (the T5 encoder was dropped in __init__).
        prompt_embeds, pooled_prompt_embeds, text_ids = self.cached_prompt_embeds, self.cached_pooled_prompt_embeds, self.cached_text_ids
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
            img = image[0] if isinstance(image, list) else image
            image_height, image_width = image.shape[2], image.shape[3]
            image = self.image_processor.resize(image, image_height, image_width)
            image = self.image_processor.preprocess(image, image_height, image_width)

        num_channels_latents = self.transformer.config.in_channels // 4
        latents, image_latents, latent_ids, image_ids = self.prepare_latents(
            image, batch_size * num_images_per_prompt, num_channels_latents, height, width,
            prompt_embeds.dtype, device, generator, latents
        )
        if image_ids is not None:
            # Conditioning-image tokens are appended after the noise tokens.
            latent_ids = torch.cat([latent_ids, image_ids], dim=0)

        mu = calculate_shift(latents.shape[1], self.scheduler.config.get("base_image_seq_len", 256),
                             self.scheduler.config.get("max_image_seq_len", 4096),
                             self.scheduler.config.get("base_shift", 0.5),
                             self.scheduler.config.get("max_shift", 1.15))
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)


        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}
        self.scheduler.set_begin_index(0)
        for i, t in self.progress_bar(enumerate(timesteps)):
            if self.interrupt: break

            self._current_timestep = t


            latent_model_input = latents

            if image_latents is not None:
                # Concatenate conditioning tokens along the sequence dim.
                latent_model_input = torch.cat([latents, image_latents], dim=1)



            # Only the first `latents.size(1)` output tokens belong to the
            # noise latents; the conditioning-token predictions are dropped.
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_ids,
                joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False
            )[0][:, :latents.size(1)]

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Pull requested tensors out of the local scope by name.
                callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)



        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            # image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()
        if not return_dict: return (image,)
        # self.output = image
        return (image + 1) / 2

    # @staticmethod
    def encode_image(self, images: torch.Tensor):
        """
        Encodes the images into tokens and ids for FLUX pipeline.
        """
        images = self.image_processor.preprocess(images)
        images = images.to(self.text_encoder.device).to(self.dtype)
        images = self.vae.encode(images).latent_dist.sample()
        images = (
            images - self.vae.config.shift_factor
        ) * self.vae.config.scaling_factor
        images_tokens = self._pack_latents(images, *images.shape)
        images_ids = self._prepare_latent_image_ids(
            images.shape[0],
            images.shape[2],
            images.shape[3],
            self.text_encoder.device,
            self.dtype,
        )
        # Retry with halved spatial dims when the token count disagrees with
        # the id count (packing halves height/width).
        if images_tokens.shape[1] != images_ids.shape[0]:
            images_ids = self._prepare_latent_image_ids(
                images.shape[0],
                images.shape[2] // 2,
                images.shape[3] // 2,
                self.text_encoder.device,
                self.dtype,
            )
        return images_tokens, images_ids
603
+
604
if __name__ == "__main__":
    # Smoke test: run one reflection-removal pass on a sample image.
    with torch.no_grad():
        from PIL import Image

        opt = {
            "model": "/home/s1023244038/kontext/",
        }
        # The class defined in this file is GenSIRR and its constructor takes
        # the model path directly; the previous `FluxModel(opt)` call named a
        # class that does not exist anywhere in this file.
        model = GenSIRR(opt["model"])

        image = Image.open("/home/s1023244038/sirs/test/Nature/blended/1_143.jpg")
        prompt = ""
        prompt_2 = ""
        # NOTE(review): forward() reads image.shape[2] / image.shape[3], so it
        # expects a 4-D (B, C, H, W) tensor — confirm the PIL image is
        # converted/preprocessed before this call.
        out = model(image=image, prompt=prompt, prompt_2=prompt_2)

        # NOTE(review): forward() returns a tensor in [0, 1]; `.save` assumes
        # a PIL image — convert (e.g. via image_processor.postprocess) first.
        out[0].save("output.png")
pooled_prompt_embeds.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22024a113602308a2d1ef8ff9bd937ba1cc4f6a69a4ba48310ae17d9e575785b
3
+ size 2781
prompt_embeds.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5a264ff538cfaa7aa3541178ecf32ffaabbd35d7d84d87145e27d4e299d70eb
3
+ size 4195514
requirements.txt CHANGED
@@ -2,4 +2,5 @@ transformers
2
  git+https://github.com/huggingface/diffusers.git
3
  accelerate
4
  safetensors
5
- sentencepiece
 
 
2
  git+https://github.com/huggingface/diffusers.git
3
  accelerate
4
  safetensors
5
+ sentencepiece
6
+ peft
text_ids.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f15f9a0720f276974ac1b8860952fa8d2777685743694334f939dc83b506e44
3
+ size 4257