Badnerle1234454 committed on
Commit dde8ccd · verified · 1 Parent(s): 4ba5be8

Upload 2 files

Files changed (2):
  1. comfyui/README.md +40 -0
  2. comfyui/withanyone_node.py +353 -0
comfyui/README.md ADDED
@@ -0,0 +1,40 @@
# WithAnyone ComfyUI Node

This folder contains a custom ComfyUI node that wraps the WithAnyone inference pipeline.

## Installation Steps

1. Copy (or symlink) the file `withanyone_node.py` into your ComfyUI `custom_nodes/` directory.
2. Make sure the WithAnyone project (this repository) and all dependencies from `requirements.txt` are installed in the same Python environment you use to run ComfyUI.
3. Download the required model checkpoints (Flux, SigLIP, CLIP, T5, WithAnyone IPA weights) to your local machine or rely on Hugging Face automatic downloads. Update the node inputs if you store them in custom locations.
4. Launch ComfyUI. The node appears in the *withanyone* category as **WithAnyone (Flux)**.

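Step 1 can be done with a symlink so ComfyUI always picks up the current version of the node. A minimal sketch; `COMFYUI_DIR` and `REPO_DIR` are assumed locations that you should adjust to your setup:

```shell
# Assumed locations - adjust to your installation.
COMFYUI_DIR="${COMFYUI_DIR:-$HOME/ComfyUI}"
REPO_DIR="${REPO_DIR:-$PWD}"

# Create the custom_nodes directory if needed and link the node file into it.
mkdir -p "$COMFYUI_DIR/custom_nodes"
ln -sfn "$REPO_DIR/comfyui/withanyone_node.py" "$COMFYUI_DIR/custom_nodes/withanyone_node.py"
```

`-n` makes `ln` replace an existing link rather than descending into it, so re-running the command after moving the repository is safe.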
## Node Inputs

| Input | Description |
| --- | --- |
| `prompt` | Text prompt used during generation. |
| `ref_images` | One or more reference portraits (tensor input from ComfyUI). Faces are detected via InsightFace. |
| `manual_bboxes` | Optional bounding boxes (`x1,y1,x2,y2` separated by semicolons, or JSON) that place each identity in the final image. Leave empty to use default layouts. |
| `width` / `height` | Output resolution (multiples of 16 recommended). |
| `num_steps` | Number of diffusion steps. |
| `guidance` | Guidance scale (CFG). |
| `seed` | Random seed for reproducibility. |
| `model_type` | Underlying Flux backbone (`flux-dev`, `flux-dev-fp8`, `flux-schnell`). |
| `id_weight` / `siglip_weight` | Weights for identity preservation vs. semantic alignment. |
| `only_lora`, `offload`, `lora_rank`, `lora_weight`, `additional_lora` | Advanced controls for LoRA usage. |
| `ipa_path`, `clip_path`, `t5_path`, `flux_path`, `siglip_path` | Paths or Hugging Face identifiers for checkpoints. |

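The two `manual_bboxes` formats can be illustrated with a standalone sketch. The node ships its own parser (`_parse_manual_bboxes`); this simplified `parse_bboxes` helper only demonstrates the accepted inputs:

```python
import json

def parse_bboxes(spec: str):
    """Parse 'x1,y1,x2,y2;...' or JSON into a list of [x1, y1, x2, y2] boxes."""
    spec = spec.strip()
    if not spec:
        return None  # empty spec -> fall back to default layouts
    try:
        parsed = json.loads(spec)
        if isinstance(parsed, dict):
            parsed = parsed.get("bboxes", [])
    except json.JSONDecodeError:
        # Not JSON: treat as semicolon-separated comma lists.
        parsed = [chunk.split(",") for chunk in spec.split(";") if chunk.strip()]
    return [[int(round(float(v))) for v in box] for box in parsed]

print(parse_bboxes("100,100,200,200; 300,100,400,200"))
# [[100, 100, 200, 200], [300, 100, 400, 200]]
print(parse_bboxes('{"bboxes": [[50, 60, 150, 160]]}'))
# [[50, 60, 150, 160]]
```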
## Outputs

The node returns:

- Generated image as a ComfyUI `IMAGE`.
- Metadata dictionary containing seed, resolution, steps, guidance, bounding boxes, and model type.

## Notes

- The node reuses the project’s `FaceExtractor` to obtain ArcFace embeddings from the provided references. If a face cannot be detected, the node raises an error.
- When no bounding boxes are supplied, a default layout is chosen based on the number of reference faces. The defaults are designed for 512×512 images; the node scales them to the requested resolution.
- For background matting or more advanced mask handling, extend the node to leverage WithAnyone’s matting utilities.

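The 512-reference scaling mentioned in the second note is plain proportional arithmetic; a sketch equivalent to the node's `_scale_bboxes`:

```python
def scale_bboxes(bboxes, width, height, reference=512):
    """Scale boxes defined on a reference-square canvas to the output resolution."""
    sx, sy = width / reference, height / reference
    return [[int(round(x1 * sx)), int(round(y1 * sy)),
             int(round(x2 * sx)), int(round(y2 * sy))]
            for x1, y1, x2, y2 in bboxes]

print(scale_bboxes([[150, 100, 250, 200]], 1024, 768))
# [[300, 150, 500, 300]]
```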
comfyui/withanyone_node.py ADDED
@@ -0,0 +1,353 @@
"""
ComfyUI custom node for running the WithAnyone pipeline.

Copy or symlink this file into your ComfyUI `custom_nodes` directory and ensure
the WithAnyone project plus its dependencies are available in the Python path.
"""

from __future__ import annotations

import json
import random
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image

try:
    from comfy import model_management
    from comfy.utils import ProgressBar
except ImportError:  # pragma: no cover - only executed outside ComfyUI
    model_management = None  # type: ignore
    ProgressBar = None  # type: ignore

from withanyone.flux.pipeline import WithAnyonePipeline
from util import FaceExtractor

DEFAULT_SINGLE_BBOXES: List[List[int]] = [
    [150, 100, 250, 200],
    [100, 100, 200, 200],
    [200, 100, 300, 200],
    [250, 100, 350, 200],
    [300, 100, 400, 200],
]

DEFAULT_DOUBLE_BBOXES: List[List[List[int]]] = [
    [[100, 100, 200, 200], [300, 100, 400, 200]],
    [[150, 100, 250, 200], [300, 100, 400, 200]],
]

PIPELINE_CACHE: Dict[Tuple, WithAnyonePipeline] = {}
FACE_EXTRACTOR: Optional[FaceExtractor] = None

def _get_device() -> torch.device:
    if model_management is not None:
        return model_management.get_torch_device()
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _get_face_extractor() -> FaceExtractor:
    global FACE_EXTRACTOR
    if FACE_EXTRACTOR is None:
        FACE_EXTRACTOR = FaceExtractor()
    return FACE_EXTRACTOR


def _select_default_bboxes(identity_count: int) -> List[List[int]]:
    # Note: the built-in defaults cover at most two identities.
    if identity_count >= 2:
        return [list(bbox) for bbox in random.choice(DEFAULT_DOUBLE_BBOXES)]
    # A single identity still needs a list of boxes, so wrap the chosen bbox.
    return [list(random.choice(DEFAULT_SINGLE_BBOXES))]

def _parse_manual_bboxes(spec: str) -> Optional[List[List[int]]]:
    if not spec or not spec.strip():
        return None

    spec = spec.strip()
    try:
        parsed = json.loads(spec)
    except json.JSONDecodeError:
        parsed = []
        for chunk in spec.split(";"):
            chunk = chunk.strip()
            if not chunk:
                continue
            values = [float(value.strip()) for value in chunk.split(",")]
            if len(values) != 4:
                raise ValueError(f"Expected 4 values per bbox, got {len(values)}: {chunk}")
            parsed.append(values)

    if isinstance(parsed, dict) and "bboxes" in parsed:
        parsed = parsed["bboxes"]

    # Reject bare strings explicitly: str is a Sequence but not a list of boxes.
    if isinstance(parsed, str) or not isinstance(parsed, Sequence):
        raise ValueError("Bounding box specification must be a list or a dictionary with 'bboxes'.")

    cleaned: List[List[int]] = []
    for entry in parsed:
        if isinstance(entry, str):
            coords = [float(value.strip()) for value in entry.split(",")]
        elif isinstance(entry, Iterable):
            coords = [float(value) for value in entry]
        else:
            raise ValueError(f"Unsupported bbox entry type: {type(entry)}")

        if len(coords) != 4:
            raise ValueError(f"Each bbox needs four coordinates, received {coords}")

        cleaned.append([int(round(coord)) for coord in coords])

    return cleaned

def _scale_bboxes(bboxes: List[List[int]], width: int, height: int, reference: int = 512) -> List[List[int]]:
    if width == reference and height == reference:
        return bboxes

    sx = width / float(reference)
    sy = height / float(reference)
    scaled = []
    for x1, y1, x2, y2 in bboxes:
        scaled.append(
            [
                int(round(x1 * sx)),
                int(round(y1 * sy)),
                int(round(x2 * sx)),
                int(round(y2 * sy)),
            ]
        )
    return scaled

def _comfy_to_pil_batch(images: torch.Tensor) -> List[Image.Image]:
    if images.ndim == 3:
        images = images.unsqueeze(0)
    pil_images: List[Image.Image] = []
    for image in images:
        array = image.detach().cpu().numpy()
        if array.dtype != np.float32 and array.dtype != np.float64:
            array = array.astype(np.float32)
        array = np.clip(array, 0.0, 1.0)
        array = (array * 255.0).astype(np.uint8)
        if array.shape[-1] == 4:
            array = array[..., :3]
        pil_images.append(Image.fromarray(array))
    return pil_images


def _pil_to_comfy_image(image: Image.Image) -> torch.Tensor:
    array = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    tensor = torch.from_numpy(array)
    tensor = tensor.unsqueeze(0)  # batch dimension
    return tensor

def _prepare_references(
    images: torch.Tensor,
    device: torch.device,
) -> Tuple[List[Image.Image], torch.Tensor]:
    face_extractor = _get_face_extractor()
    ref_pil: List[Image.Image] = []
    arc_embeddings: List[torch.Tensor] = []

    for pil_image in _comfy_to_pil_batch(images):
        ref_img, embedding = face_extractor.extract(pil_image)
        if ref_img is None or embedding is None:
            raise RuntimeError("Failed to extract a face embedding from the provided reference image.")
        ref_pil.append(ref_img)
        arc_embeddings.append(torch.tensor(embedding, dtype=torch.float32, device=device))

    arcface_tensor = torch.stack(arc_embeddings, dim=0)
    return ref_pil, arcface_tensor

def _get_pipeline(
    model_type: str,
    ipa_path: str,
    clip_path: str,
    t5_path: str,
    flux_path: str,
    siglip_path: str,
    only_lora: bool,
    offload: bool,
    lora_rank: int,
    lora_weight: float,
    additional_lora: Optional[str],
) -> WithAnyonePipeline:
    device = _get_device()
    cache_key = (
        model_type,
        ipa_path,
        clip_path,
        t5_path,
        flux_path,
        siglip_path,
        only_lora,
        offload,
        lora_rank,
        lora_weight,
        additional_lora,
        device.type,
    )

    pipeline = PIPELINE_CACHE.get(cache_key)
    if pipeline is None:
        face_extractor = _get_face_extractor()
        pipeline = WithAnyonePipeline(
            model_type=model_type,
            ipa_path=ipa_path,
            device=device,
            offload=offload,
            only_lora=only_lora,
            lora_rank=lora_rank,
            face_extractor=face_extractor,
            additional_lora_ckpt=additional_lora,
            lora_weight=lora_weight,
            clip_path=clip_path,
            t5_path=t5_path,
            flux_path=flux_path,
            siglip_path=siglip_path,
        )
        PIPELINE_CACHE[cache_key] = pipeline
    else:
        pipeline.device = device

    return pipeline

class WithAnyoneNode:
    """
    ComfyUI node that wraps the WithAnyone inference pipeline.
    """

    @classmethod
    def INPUT_TYPES(cls):  # noqa: N802 - ComfyUI API
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "ref_images": ("IMAGE",),
            },
            "optional": {
                "manual_bboxes": ("STRING", {"default": ""}),
                "width": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}),
                "height": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}),
                "num_steps": ("INT", {"default": 25, "min": 5, "max": 100, "step": 1}),
                "guidance": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 25.0, "step": 0.1}),
                "seed": ("INT", {"default": 1234, "min": 0, "max": 2**32 - 1}),
                "model_type": (["flux-dev", "flux-dev-fp8", "flux-schnell"], {"default": "flux-dev"}),
                "id_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}),
                "siglip_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}),
                "only_lora": ("BOOLEAN", {"default": True}),
                "offload": ("BOOLEAN", {"default": False}),
                "lora_rank": ("INT", {"default": 64, "min": 1, "max": 128, "step": 1}),
                "lora_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05}),
                "additional_lora": ("STRING", {"default": ""}),
                "ipa_path": ("STRING", {"default": "WithAnyone/WithAnyone"}),
                "clip_path": ("STRING", {"default": "openai/clip-vit-large-patch14"}),
                "t5_path": ("STRING", {"default": "xlabs-ai/xflux_text_encoders"}),
                "flux_path": ("STRING", {"default": "black-forest-labs/FLUX.1-dev"}),
                "siglip_path": ("STRING", {"default": "google/siglip-base-patch16-256-i18n"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "DICT")
    RETURN_NAMES = ("image", "info")
    FUNCTION = "generate"
    CATEGORY = "withanyone"

    def _create_progress_bar(self, steps: int):
        if ProgressBar is None:
            return None
        return ProgressBar(steps)

    def generate(  # noqa: C901 - ComfyUI entry points are typically long
        self,
        prompt: str,
        ref_images: torch.Tensor,
        manual_bboxes: str = "",
        width: int = 512,
        height: int = 512,
        num_steps: int = 25,
        guidance: float = 4.0,
        seed: int = 1234,
        model_type: str = "flux-dev",
        id_weight: float = 1.0,
        siglip_weight: float = 1.0,
        only_lora: bool = True,
        offload: bool = False,
        lora_rank: int = 64,
        lora_weight: float = 1.0,
        additional_lora: str = "",
        ipa_path: str = "WithAnyone/WithAnyone",
        clip_path: str = "openai/clip-vit-large-patch14",
        t5_path: str = "xlabs-ai/xflux_text_encoders",
        flux_path: str = "black-forest-labs/FLUX.1-dev",
        siglip_path: str = "google/siglip-base-patch16-256-i18n",
    ):
        additional_lora_ckpt = additional_lora if additional_lora.strip() else None
        device = _get_device()
        progress = self._create_progress_bar(num_steps)

        pipeline = _get_pipeline(
            model_type=model_type,
            ipa_path=ipa_path,
            clip_path=clip_path,
            t5_path=t5_path,
            flux_path=flux_path,
            siglip_path=siglip_path,
            only_lora=only_lora,
            offload=offload,
            lora_rank=lora_rank,
            lora_weight=lora_weight,
            additional_lora=additional_lora_ckpt,
        )

        ref_imgs_pil, arcface_embeddings = _prepare_references(ref_images, device=device)

        parsed_bboxes = _parse_manual_bboxes(manual_bboxes)
        if parsed_bboxes is None:
            # Default layouts are defined on a 512x512 canvas; scale them to the output size.
            parsed_bboxes = _select_default_bboxes(len(ref_imgs_pil))
            parsed_bboxes = _scale_bboxes(parsed_bboxes, width, height)

        result_image = pipeline(
            prompt=prompt,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=seed,
            ref_imgs=ref_imgs_pil,
            arcface_embeddings=arcface_embeddings,
            bboxes=[parsed_bboxes],
            id_weight=id_weight,
            siglip_weight=siglip_weight,
        )

        if progress is not None:
            progress.update_absolute(num_steps, num_steps)

        output_tensor = _pil_to_comfy_image(result_image)
        info = {
            "seed": seed,
            "width": width,
            "height": height,
            "guidance": guidance,
            "num_steps": num_steps,
            "bboxes": parsed_bboxes,
            "model_type": model_type,
        }

        return output_tensor, info


NODE_CLASS_MAPPINGS = {
    "WithAnyoneGenerate": WithAnyoneNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WithAnyoneGenerate": "WithAnyone (Flux)",
}