""" ComfyUI custom node for running the WithAnyone pipeline. Copy or symlink this file into your ComfyUI `custom_nodes` directory and ensure the WithAnyone project plus its dependencies are available in the Python path. """ from __future__ import annotations import json import random from typing import Dict, Iterable, List, Optional, Sequence, Tuple import numpy as np import torch from PIL import Image try: from comfy import model_management from comfy.utils import ProgressBar except ImportError: # pragma: no cover - only executed outside ComfyUI model_management = None # type: ignore ProgressBar = None # type: ignore from withanyone.flux.pipeline import WithAnyonePipeline from util import FaceExtractor DEFAULT_SINGLE_BBOXES: List[List[int]] = [ [150, 100, 250, 200], [100, 100, 200, 200], [200, 100, 300, 200], [250, 100, 350, 200], [300, 100, 400, 200], ] DEFAULT_DOUBLE_BBOXES: List[List[List[int]]] = [ [[100, 100, 200, 200], [300, 100, 400, 200]], [[150, 100, 250, 200], [300, 100, 400, 200]], ] PIPELINE_CACHE: Dict[Tuple, WithAnyonePipeline] = {} FACE_EXTRACTOR: Optional[FaceExtractor] = None def _get_device() -> torch.device: if model_management is not None: return model_management.get_torch_device() return torch.device("cuda" if torch.cuda.is_available() else "cpu") def _get_face_extractor() -> FaceExtractor: global FACE_EXTRACTOR if FACE_EXTRACTOR is None: FACE_EXTRACTOR = FaceExtractor() return FACE_EXTRACTOR def _select_default_bboxes(identity_count: int) -> List[List[int]]: if identity_count >= 2: return [*random.choice(DEFAULT_DOUBLE_BBOXES)] return [*DEFAULT_SINGLE_BBOXES[random.randrange(len(DEFAULT_SINGLE_BBOXES))]] def _parse_manual_bboxes(spec: str) -> Optional[List[List[int]]]: if not spec or not spec.strip(): return None spec = spec.strip() try: parsed = json.loads(spec) except json.JSONDecodeError: parsed = [] for chunk in spec.split(";"): chunk = chunk.strip() if not chunk: continue values = [float(value.strip()) for value in chunk.split(",")] if len(values) != 4: raise ValueError(f"Expected 4 values per bbox, got {len(values)}: {chunk}") parsed.append(values) if isinstance(parsed, dict) and "bboxes" in parsed: parsed = parsed["bboxes"] if not isinstance(parsed, Sequence): raise ValueError("Bounding box specification must be a list or dictionary with 'bboxes'.") cleaned: List[List[int]] = [] for entry in parsed: if isinstance(entry, str): coords = [float(value.strip()) for value in entry.split(",")] elif isinstance(entry, Iterable): coords = [float(value) for value in entry] else: raise ValueError(f"Unsupported bbox entry type: {type(entry)}") if len(coords) != 4: raise ValueError(f"Each bbox needs four coordinates, received {coords}") cleaned.append([int(round(coord)) for coord in coords]) return cleaned def _scale_bboxes(bboxes: List[List[int]], width: int, height: int, reference: int = 512) -> List[List[int]]: if width == reference and height == reference: return bboxes sx = width / float(reference) sy = height / float(reference) scaled = [] for x1, y1, x2, y2 in bboxes: scaled.append( [ int(round(x1 * sx)), int(round(y1 * sy)), int(round(x2 * sx)), int(round(y2 * sy)), ] ) return scaled def _comfy_to_pil_batch(images: torch.Tensor) -> List[Image.Image]: if images.ndim == 3: images = images.unsqueeze(0) pil_images: List[Image.Image] = [] for image in images: array = image.detach().cpu().numpy() if array.dtype != np.float32 and array.dtype != np.float64: array = array.astype(np.float32) array = np.clip(array, 0.0, 1.0) array = (array * 255.0).astype(np.uint8) if array.shape[-1] == 4: array = array[..., :3] pil_images.append(Image.fromarray(array)) return pil_images def _pil_to_comfy_image(image: Image.Image) -> torch.Tensor: array = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0 tensor = torch.from_numpy(array) tensor = tensor.unsqueeze(0) # batch dimension return tensor def _prepare_references( images: torch.Tensor, device: torch.device, ) -> Tuple[List[Image.Image], torch.Tensor]: face_extractor = _get_face_extractor() ref_pil: List[Image.Image] = [] arc_embeddings: List[torch.Tensor] = [] for pil_image in _comfy_to_pil_batch(images): ref_img, embedding = face_extractor.extract(pil_image) if ref_img is None or embedding is None: raise RuntimeError("Failed to extract a face embedding from the provided reference image.") ref_pil.append(ref_img) arc_embeddings.append(torch.tensor(embedding, dtype=torch.float32, device=device)) arcface_tensor = torch.stack(arc_embeddings, dim=0) return ref_pil, arcface_tensor def _get_pipeline( model_type: str, ipa_path: str, clip_path: str, t5_path: str, flux_path: str, siglip_path: str, only_lora: bool, offload: bool, lora_rank: int, lora_weight: float, additional_lora: Optional[str], ) -> WithAnyonePipeline: device = _get_device() cache_key = ( model_type, ipa_path, clip_path, t5_path, flux_path, siglip_path, only_lora, offload, lora_rank, lora_weight, additional_lora, device.type, ) pipeline = PIPELINE_CACHE.get(cache_key) if pipeline is None: face_extractor = _get_face_extractor() pipeline = WithAnyonePipeline( model_type=model_type, ipa_path=ipa_path, device=device, offload=offload, only_lora=only_lora, lora_rank=lora_rank, face_extractor=face_extractor, additional_lora_ckpt=additional_lora, lora_weight=lora_weight, clip_path=clip_path, t5_path=t5_path, flux_path=flux_path, siglip_path=siglip_path, ) PIPELINE_CACHE[cache_key] = pipeline else: pipeline.device = device return pipeline class WithAnyoneNode: """ ComfyUI node that wraps the WithAnyone inference pipeline. """ @classmethod def INPUT_TYPES(cls): # noqa: N802 - ComfyUI API return { "required": { "prompt": ("STRING", {"multiline": True, "default": ""}), "ref_images": ("IMAGE",), }, "optional": { "manual_bboxes": ("STRING", {"default": ""}), "width": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}), "height": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}), "num_steps": ("INT", {"default": 25, "min": 5, "max": 100, "step": 1}), "guidance": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 25.0, "step": 0.1}), "seed": ("INT", {"default": 1234, "min": 0, "max": 2**32 - 1}), "model_type": (["flux-dev", "flux-dev-fp8", "flux-schnell"], {"default": "flux-dev"}), "id_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}), "siglip_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}), "only_lora": ("BOOLEAN", {"default": True}), "offload": ("BOOLEAN", {"default": False}), "lora_rank": ("INT", {"default": 64, "min": 1, "max": 128, "step": 1}), "lora_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05}), "additional_lora": ("STRING", {"default": ""}), "ipa_path": ("STRING", {"default": "WithAnyone/WithAnyone"}), "clip_path": ("STRING", {"default": "openai/clip-vit-large-patch14"}), "t5_path": ("STRING", {"default": "xlabs-ai/xflux_text_encoders"}), "flux_path": ("STRING", {"default": "black-forest-labs/FLUX.1-dev"}), "siglip_path": ("STRING", {"default": "google/siglip-base-patch16-256-i18n"}), }, } RETURN_TYPES = ("IMAGE", "DICT") RETURN_NAMES = ("image", "info") FUNCTION = "generate" CATEGORY = "withanyone" def _create_progress_bar(self, steps: int): if ProgressBar is None: return None return ProgressBar(steps) def generate( # noqa: C901 - ComfyUI entry points are typically long self, prompt: str, ref_images: torch.Tensor, manual_bboxes: str = "", width: int = 512, height: int = 512, num_steps: int = 25, guidance: float = 4.0, seed: int = 1234, model_type: str = "flux-dev", id_weight: float = 1.0, siglip_weight: float = 1.0, only_lora: bool = True, offload: bool = False, lora_rank: int = 64, lora_weight: float = 1.0, additional_lora: str = "", ipa_path: str = "WithAnyone/WithAnyone", clip_path: str = "openai/clip-vit-large-patch14", t5_path: str = "xlabs-ai/xflux_text_encoders", flux_path: str = "black-forest-labs/FLUX.1-dev", siglip_path: str = "google/siglip-base-patch16-256-i18n", ): additional_lora_ckpt = additional_lora if additional_lora.strip() else None device = _get_device() progress = self._create_progress_bar(num_steps) pipeline = _get_pipeline( model_type=model_type, ipa_path=ipa_path, clip_path=clip_path, t5_path=t5_path, flux_path=flux_path, siglip_path=siglip_path, only_lora=only_lora, offload=offload, lora_rank=lora_rank, lora_weight=lora_weight, additional_lora=additional_lora_ckpt, ) ref_imgs_pil, arcface_embeddings = _prepare_references(ref_images, device=device) parsed_bboxes = _parse_manual_bboxes(manual_bboxes) if parsed_bboxes is None: parsed_bboxes = _select_default_bboxes(len(ref_imgs_pil)) parsed_bboxes = _scale_bboxes(parsed_bboxes, width, height) result_image = pipeline( prompt=prompt, width=width, height=height, guidance=guidance, num_steps=num_steps, seed=seed, ref_imgs=ref_imgs_pil, arcface_embeddings=arcface_embeddings, bboxes=[parsed_bboxes], id_weight=id_weight, siglip_weight=siglip_weight, ) if progress is not None: progress.update_absolute(num_steps, num_steps) output_tensor = _pil_to_comfy_image(result_image) info = { "seed": seed, "width": width, "height": height, "guidance": guidance, "num_steps": num_steps, "bboxes": parsed_bboxes, "model_type": model_type, } return output_tensor, info NODE_CLASS_MAPPINGS = { "WithAnyoneGenerate": WithAnyoneNode, } NODE_DISPLAY_NAME_MAPPINGS = { "WithAnyoneGenerate": "WithAnyone (Flux)", }