Maxclon committed on
Commit
7480dd0
·
1 Parent(s): c4206ff

Update pulidflux.py

Browse files
Files changed (1) hide show
  1. pulidflux.py +148 -15
pulidflux.py CHANGED
@@ -12,6 +12,8 @@ from insightface.app import FaceAnalysis
12
  from facexlib.parsing import init_parsing_model
13
  from facexlib.utils.face_restoration_helper import FaceRestoreHelper
14
 
 
 
15
  from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
16
  from .encoders_flux import IDFormer, PerceiverAttentionCA
17
 
@@ -24,6 +26,8 @@ else:
24
  current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
25
  folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)
26
 
 
 
27
  class PulidFluxModel(nn.Module):
28
  def __init__(self):
29
  super().__init__()
@@ -72,7 +76,12 @@ def forward_orig(
72
  y: Tensor,
73
  guidance: Tensor = None,
74
  control=None,
 
 
 
75
  ) -> Tensor:
 
 
76
  if img.ndim != 3 or txt.ndim != 3:
77
  raise ValueError("Input img and txt tensors must have 3 dimensions.")
78
 
@@ -91,8 +100,32 @@ def forward_orig(
91
  pe = self.pe_embedder(ids)
92
 
93
  ca_idx = 0
 
94
  for i, block in enumerate(self.double_blocks):
95
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if control is not None: # Controlnet
98
  control_i = control.get("input")
@@ -106,14 +139,34 @@ def forward_orig(
106
  if i % self.pulid_double_interval == 0:
107
  # Will calculate influence of all pulid nodes at once
108
  for _, node_data in self.pulid_data.items():
109
- if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end'])):
 
 
 
 
 
110
  img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img)
111
  ca_idx += 1
112
 
113
  img = torch.cat((txt, img), 1)
114
-
115
  for i, block in enumerate(self.single_blocks):
116
- img = block(img, vec=vec, pe=pe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  if control is not None: # Controlnet
119
  control_o = control.get("output")
@@ -122,13 +175,20 @@ def forward_orig(
122
  if add is not None:
123
  img[:, txt.shape[1] :, ...] += add
124
 
 
125
  # PuLID attention
126
  if self.pulid_data:
127
  real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...]
128
  if i % self.pulid_single_interval == 0:
129
  # Will calculate influence of all nodes at once
130
  for _, node_data in self.pulid_data.items():
131
- if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end'])):
 
 
 
 
 
 
132
  real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img)
133
  ca_idx += 1
134
  img = torch.cat((txt, real_img), 1)
@@ -148,6 +208,29 @@ def image_to_tensor(image):
148
  tensor = tensor[..., [2, 1, 0]]
149
  return tensor
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def to_gray(img):
152
  x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
153
  x = x.repeat(1, 3, 1, 1)
@@ -227,7 +310,7 @@ class PulidFluxEvaClipLoader:
227
 
228
  class ApplyPulidFlux:
229
  @classmethod
230
- def INPUT_TYPES(s):
231
  return {
232
  "required": {
233
  "model": ("MODEL", ),
@@ -238,9 +321,15 @@ class ApplyPulidFlux:
238
  "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
239
  "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
240
  "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
 
 
 
 
 
241
  },
242
  "optional": {
243
  "attn_mask": ("MASK", ),
 
244
  },
245
  "hidden": {
246
  "unique_id": "UNIQUE_ID"
@@ -254,15 +343,13 @@ class ApplyPulidFlux:
254
  def __init__(self):
255
  self.pulid_data_dict = None
256
 
257
- def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, attn_mask=None, unique_id=None):
258
  device = comfy.model_management.get_torch_device()
259
  # Why should I care what args say, when the unet model has a different dtype?!
260
  # Am I missing something?!
261
  #dtype = comfy.model_management.unet_dtype()
262
  dtype = model.model.diffusion_model.dtype
263
- # Because of 8bit models we must check what cast type does the unet uses
264
- # ZLUDA (Intel, AMD) & GPUs with compute capability < 8.0 don't support bfloat16 etc.
265
- # Issue: https://github.com/balazik/ComfyUI-PuLID-Flux/issues/6
266
  if model.model.manual_cast_dtype is not None:
267
  dtype = model.model.manual_cast_dtype
268
 
@@ -277,6 +364,9 @@ class ApplyPulidFlux:
277
  attn_mask = attn_mask.unsqueeze(0)
278
  attn_mask = attn_mask.to(device, dtype=dtype)
279
 
 
 
 
280
  image = tensor_to_image(image)
281
 
282
  face_helper = FaceRestoreHelper(
@@ -333,7 +423,11 @@ class ApplyPulidFlux:
333
  bg = sum(parsing_out == i for i in bg_label).bool()
334
  white_image = torch.ones_like(align_face)
335
  # Only keep the face features
336
- face_features_image = torch.where(bg, white_image, to_gray(align_face))
 
 
 
 
337
 
338
  # Transform img before sending to eva_clip
339
  # Apparently MPS only supports NEAREST interpolation?
@@ -359,10 +453,49 @@ class ApplyPulidFlux:
359
  logging.warning("PuLID warning: No faces detected in any of the given images, returning unmodified model.")
360
  return (model,)
361
 
362
- # average embeddings
363
- cond = torch.cat(cond).to(device, dtype=dtype)
364
- if cond.shape[0] > 1:
365
- cond = torch.mean(cond, dim=0, keepdim=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
  sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
368
  sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
 
12
  from facexlib.parsing import init_parsing_model
13
  from facexlib.utils.face_restoration_helper import FaceRestoreHelper
14
 
15
+ import torch.nn.functional as F
16
+
17
  from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
18
  from .encoders_flux import IDFormer, PerceiverAttentionCA
19
 
 
26
  current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
27
  folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)
28
 
29
+ from .online_train2 import online_train
30
+
31
  class PulidFluxModel(nn.Module):
32
  def __init__(self):
33
  super().__init__()
 
76
  y: Tensor,
77
  guidance: Tensor = None,
78
  control=None,
79
+ transformer_options={},
80
+ attn_mask: Tensor = None,
81
+ **kwargs # so it won't break if we add more stuff in the future
82
  ) -> Tensor:
83
+ patches_replace = transformer_options.get("patches_replace", {})
84
+
85
  if img.ndim != 3 or txt.ndim != 3:
86
  raise ValueError("Input img and txt tensors must have 3 dimensions.")
87
 
 
100
  pe = self.pe_embedder(ids)
101
 
102
  ca_idx = 0
103
+ blocks_replace = patches_replace.get("dit", {})
104
  for i, block in enumerate(self.double_blocks):
105
+ if ("double_block", i) in blocks_replace:
106
+ def block_wrap(args):
107
+ out = {}
108
+ out["img"], out["txt"] = block(img=args["img"],
109
+ txt=args["txt"],
110
+ vec=args["vec"],
111
+ pe=args["pe"],
112
+ attn_mask=args.get("attn_mask"))
113
+ return out
114
+
115
+ out = blocks_replace[("double_block", i)]({"img": img,
116
+ "txt": txt,
117
+ "vec": vec,
118
+ "pe": pe,
119
+ "attn_mask": attn_mask},
120
+ {"original_block": block_wrap})
121
+ txt = out["txt"]
122
+ img = out["img"]
123
+ else:
124
+ img, txt = block(img=img,
125
+ txt=txt,
126
+ vec=vec,
127
+ pe=pe,
128
+ attn_mask=attn_mask)
129
 
130
  if control is not None: # Controlnet
131
  control_i = control.get("input")
 
139
  if i % self.pulid_double_interval == 0:
140
  # Will calculate influence of all pulid nodes at once
141
  for _, node_data in self.pulid_data.items():
142
+ condition_start = node_data['sigma_start'] >= timesteps
143
+ condition_end = timesteps >= node_data['sigma_end']
144
+ condition = torch.logical_and(
145
+ condition_start, condition_end).all()
146
+
147
+ if condition:
148
  img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img)
149
  ca_idx += 1
150
 
151
  img = torch.cat((txt, img), 1)
 
152
  for i, block in enumerate(self.single_blocks):
153
+ if ("single_block", i) in blocks_replace:
154
+ def block_wrap(args):
155
+ out = {}
156
+ out["img"] = block(args["img"],
157
+ vec=args["vec"],
158
+ pe=args["pe"],
159
+ attn_mask=args.get("attn_mask"))
160
+ return out
161
+
162
+ out = blocks_replace[("single_block", i)]({"img": img,
163
+ "vec": vec,
164
+ "pe": pe,
165
+ "attn_mask": attn_mask},
166
+ {"original_block": block_wrap})
167
+ img = out["img"]
168
+ else:
169
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
170
 
171
  if control is not None: # Controlnet
172
  control_o = control.get("output")
 
175
  if add is not None:
176
  img[:, txt.shape[1] :, ...] += add
177
 
178
+
179
  # PuLID attention
180
  if self.pulid_data:
181
  real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...]
182
  if i % self.pulid_single_interval == 0:
183
  # Will calculate influence of all nodes at once
184
  for _, node_data in self.pulid_data.items():
185
+ condition_start = node_data['sigma_start'] >= timesteps
186
+ condition_end = timesteps >= node_data['sigma_end']
187
+
188
+ # Combine conditions and reduce to a single boolean
189
+ condition = torch.logical_and(condition_start, condition_end).all()
190
+
191
+ if condition:
192
  real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img)
193
  ca_idx += 1
194
  img = torch.cat((txt, real_img), 1)
 
208
  tensor = tensor[..., [2, 1, 0]]
209
  return tensor
210
 
211
def resize_with_pad(img, target_size):  # image: 1, h, w, 3
    """Resize a channels-last image to fit inside ``target_size``, preserving
    aspect ratio, then center it on a black (zero) canvas of that size.

    ``img`` is expected in NHWC layout; the result is returned in NHWC too.
    """
    # Work in NCHW, which is what F.interpolate / F.pad expect.
    nchw = img.permute(0, 3, 1, 2)
    target_h, target_w = target_size
    src_h, src_w = nchw.shape[2], nchw.shape[3]

    # Uniform scale: the tighter of the two axis ratios keeps aspect ratio.
    ratio = min(target_h / src_h, target_w / src_w)

    # Clamp against the target so float rounding can never overshoot it.
    fit_h = int(min(src_h * ratio, target_h))
    fit_w = int(min(src_w * ratio, target_w))

    nchw = F.interpolate(nchw, size=(fit_h, fit_w), mode='bicubic', align_corners=False)

    # Split the leftover space as evenly as possible (extra pixel goes to
    # the bottom/right edge), padding with constant zeros (black).
    top = (target_h - fit_h) // 2
    left = (target_w - fit_w) // 2
    bottom = target_h - fit_h - top
    right = target_w - fit_w - left
    nchw = F.pad(nchw, pad=(left, right, top, bottom), mode='constant', value=0)

    # Back to the caller's channels-last layout.
    return nchw.permute(0, 2, 3, 1)
233
+
234
  def to_gray(img):
235
  x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
236
  x = x.repeat(1, 3, 1, 1)
 
310
 
311
  class ApplyPulidFlux:
312
  @classmethod
313
+ def INPUT_TYPES(s):
314
  return {
315
  "required": {
316
  "model": ("MODEL", ),
 
321
  "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
322
  "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
323
  "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
324
+ "fusion": (["mean","concat","max","norm_id","max_token","auto_weight","train_weight"],),
325
+ "fusion_weight_max": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 20.0, "step": 0.1 }),
326
+ "fusion_weight_min": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.1 }),
327
+ "train_step": ("INT", {"default": 1000, "min": 0, "max": 20000, "step": 1 }),
328
+ "use_gray": ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"}),
329
  },
330
  "optional": {
331
  "attn_mask": ("MASK", ),
332
+ "prior_image": ("IMAGE",), # for train weight, as the target
333
  },
334
  "hidden": {
335
  "unique_id": "UNIQUE_ID"
 
343
  def __init__(self):
344
  self.pulid_data_dict = None
345
 
346
+ def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, prior_image=None,fusion="mean", fusion_weight_max=1.0, fusion_weight_min=0.0, train_step=1000, use_gray=True, attn_mask=None, unique_id=None):
347
  device = comfy.model_management.get_torch_device()
348
  # Why should I care what args say, when the unet model has a different dtype?!
349
  # Am I missing something?!
350
  #dtype = comfy.model_management.unet_dtype()
351
  dtype = model.model.diffusion_model.dtype
352
+ # For 8bit use bfloat16 (because ufunc_add_CUDA is not implemented)
 
 
353
  if model.model.manual_cast_dtype is not None:
354
  dtype = model.model.manual_cast_dtype
355
 
 
364
  attn_mask = attn_mask.unsqueeze(0)
365
  attn_mask = attn_mask.to(device, dtype=dtype)
366
 
367
+ if prior_image is not None:
368
+ prior_image = resize_with_pad(prior_image.to(image.device, dtype=image.dtype), target_size=(image.shape[1], image.shape[2]))
369
+ image=torch.cat((prior_image,image),dim=0)
370
  image = tensor_to_image(image)
371
 
372
  face_helper = FaceRestoreHelper(
 
423
  bg = sum(parsing_out == i for i in bg_label).bool()
424
  white_image = torch.ones_like(align_face)
425
  # Only keep the face features
426
+ if use_gray:
427
+ _align_face = to_gray(align_face)
428
+ else:
429
+ _align_face = align_face
430
+ face_features_image = torch.where(bg, white_image, _align_face)
431
 
432
  # Transform img before sending to eva_clip
433
  # Apparently MPS only supports NEAREST interpolation?
 
453
  logging.warning("PuLID warning: No faces detected in any of the given images, returning unmodified model.")
454
  return (model,)
455
 
456
+ # fusion embeddings
457
+ if fusion == "mean":
458
+ cond = torch.cat(cond).to(device, dtype=dtype) # N,32,2048
459
+ if cond.shape[0] > 1:
460
+ cond = torch.mean(cond, dim=0, keepdim=True)
461
+ elif fusion == "concat":
462
+ cond = torch.cat(cond, dim=1).to(device, dtype=dtype)
463
+ elif fusion == "max":
464
+ cond = torch.cat(cond).to(device, dtype=dtype)
465
+ if cond.shape[0] > 1:
466
+ cond = torch.max(cond, dim=0, keepdim=True)[0]
467
+ elif fusion == "norm_id":
468
+ cond = torch.cat(cond).to(device, dtype=dtype)
469
+ if cond.shape[0] > 1:
470
+ norm=torch.norm(cond,dim=(1,2))
471
+ norm=norm/torch.sum(norm)
472
+ cond=torch.einsum("wij,w->ij",cond,norm).unsqueeze(0)
473
+ elif fusion == "max_token":
474
+ cond = torch.cat(cond).to(device, dtype=dtype)
475
+ if cond.shape[0] > 1:
476
+ norm=torch.norm(cond,dim=2)
477
+ _,idx=torch.max(norm,dim=0)
478
+ cond=torch.stack([cond[j,i] for i,j in enumerate(idx)]).unsqueeze(0)
479
+ elif fusion == "auto_weight": # 🤔
480
+ cond = torch.cat(cond).to(device, dtype=dtype)
481
+ if cond.shape[0] > 1:
482
+ norm=torch.norm(cond,dim=2)
483
+ order=torch.argsort(norm,descending=False,dim=0)
484
+ regular_weight=torch.linspace(fusion_weight_min,fusion_weight_max,norm.shape[0]).to(device, dtype=dtype)
485
+
486
+ _cond=[]
487
+ for i in range(cond.shape[1]):
488
+ o=order[:,i]
489
+ _cond.append(torch.einsum('ij,i->j',cond[:,i,:],regular_weight[o]))
490
+ cond=torch.stack(_cond,dim=0).unsqueeze(0)
491
+ elif fusion == "train_weight":
492
+ cond = torch.cat(cond).to(device, dtype=dtype)
493
+ if cond.shape[0] > 1:
494
+ if train_step > 0:
495
+ with torch.inference_mode(False):
496
+ cond = online_train(cond, device=cond.device, step=train_step)
497
+ else:
498
+ cond = torch.mean(cond, dim=0, keepdim=True)
499
 
500
  sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
501
  sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)