|
|
""" |
|
|
ComfyUI custom node for running the WithAnyone pipeline. |
|
|
|
|
|
Copy or symlink this file into your ComfyUI `custom_nodes` directory and ensure |
|
|
the WithAnyone project plus its dependencies are available in the Python path. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import random |
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
|
|
|
# ComfyUI's runtime helpers are optional so this module can also be imported
# outside a ComfyUI process (e.g. for testing or offline use).
try:
    from comfy import model_management
    from comfy.utils import ProgressBar
except ImportError:
    # Fallbacks when ComfyUI is absent: device selection degrades to a plain
    # CUDA/CPU check (see _get_device) and progress reporting is disabled
    # (see WithAnyoneNode._create_progress_bar).
    model_management = None
    ProgressBar = None
|
|
|
|
|
from withanyone.flux.pipeline import WithAnyonePipeline |
|
|
from util import FaceExtractor |
|
|
|
|
|
|
|
|
# Fallback face placements used when the user supplies no manual bboxes.
# Each box is [x1, y1, x2, y2] on a 512x512 reference canvas; _scale_bboxes()
# adapts them to the requested output resolution.

# Candidate single-person layouts (one is picked at random).
DEFAULT_SINGLE_BBOXES: List[List[int]] = [
    [150, 100, 250, 200],
    [100, 100, 200, 200],
    [200, 100, 300, 200],
    [250, 100, 350, 200],
    [300, 100, 400, 200],
]

# Candidate two-person layouts (one pair is picked at random).
DEFAULT_DOUBLE_BBOXES: List[List[List[int]]] = [
    [[100, 100, 200, 200], [300, 100, 400, 200]],
    [[150, 100, 250, 200], [300, 100, 400, 200]],
]

# Pipelines are expensive to build, so they are cached per full configuration
# tuple (model paths, LoRA settings, device type, ...); see _get_pipeline().
PIPELINE_CACHE: Dict[Tuple, WithAnyonePipeline] = {}
# Lazily created process-wide FaceExtractor singleton; see _get_face_extractor().
FACE_EXTRACTOR: Optional[FaceExtractor] = None
|
|
|
|
|
|
|
|
def _get_device() -> torch.device:
    """Return the torch device to run on.

    Prefers ComfyUI's model-management choice when available; otherwise
    falls back to CUDA if present, else CPU.
    """
    if model_management is None:
        # Running outside ComfyUI.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model_management.get_torch_device()
|
|
|
|
|
|
|
|
def _get_face_extractor() -> FaceExtractor:
    """Return the shared FaceExtractor, creating it lazily on first use."""
    global FACE_EXTRACTOR
    if FACE_EXTRACTOR is None:
        FACE_EXTRACTOR = FaceExtractor()
    return FACE_EXTRACTOR
|
|
|
|
|
|
|
|
def _select_default_bboxes(identity_count: int) -> List[List[int]]: |
|
|
if identity_count >= 2: |
|
|
return [*random.choice(DEFAULT_DOUBLE_BBOXES)] |
|
|
return [*DEFAULT_SINGLE_BBOXES[random.randrange(len(DEFAULT_SINGLE_BBOXES))]] |
|
|
|
|
|
|
|
|
def _parse_manual_bboxes(spec: str) -> Optional[List[List[int]]]: |
|
|
if not spec or not spec.strip(): |
|
|
return None |
|
|
|
|
|
spec = spec.strip() |
|
|
try: |
|
|
parsed = json.loads(spec) |
|
|
except json.JSONDecodeError: |
|
|
parsed = [] |
|
|
for chunk in spec.split(";"): |
|
|
chunk = chunk.strip() |
|
|
if not chunk: |
|
|
continue |
|
|
values = [float(value.strip()) for value in chunk.split(",")] |
|
|
if len(values) != 4: |
|
|
raise ValueError(f"Expected 4 values per bbox, got {len(values)}: {chunk}") |
|
|
parsed.append(values) |
|
|
|
|
|
if isinstance(parsed, dict) and "bboxes" in parsed: |
|
|
parsed = parsed["bboxes"] |
|
|
|
|
|
if not isinstance(parsed, Sequence): |
|
|
raise ValueError("Bounding box specification must be a list or dictionary with 'bboxes'.") |
|
|
|
|
|
cleaned: List[List[int]] = [] |
|
|
for entry in parsed: |
|
|
if isinstance(entry, str): |
|
|
coords = [float(value.strip()) for value in entry.split(",")] |
|
|
elif isinstance(entry, Iterable): |
|
|
coords = [float(value) for value in entry] |
|
|
else: |
|
|
raise ValueError(f"Unsupported bbox entry type: {type(entry)}") |
|
|
|
|
|
if len(coords) != 4: |
|
|
raise ValueError(f"Each bbox needs four coordinates, received {coords}") |
|
|
|
|
|
cleaned.append([int(round(coord)) for coord in coords]) |
|
|
|
|
|
return cleaned |
|
|
|
|
|
|
|
|
def _scale_bboxes(bboxes: List[List[int]], width: int, height: int, reference: int = 512) -> List[List[int]]: |
|
|
if width == reference and height == reference: |
|
|
return bboxes |
|
|
|
|
|
sx = width / float(reference) |
|
|
sy = height / float(reference) |
|
|
scaled = [] |
|
|
for x1, y1, x2, y2 in bboxes: |
|
|
scaled.append( |
|
|
[ |
|
|
int(round(x1 * sx)), |
|
|
int(round(y1 * sy)), |
|
|
int(round(x2 * sx)), |
|
|
int(round(y2 * sy)), |
|
|
] |
|
|
) |
|
|
return scaled |
|
|
|
|
|
|
|
|
def _comfy_to_pil_batch(images: torch.Tensor) -> List[Image.Image]:
    """Convert a ComfyUI image tensor into a list of PIL images.

    ComfyUI images are float tensors in [0, 1] shaped (B, H, W, C); a single
    (H, W, C) image is treated as a batch of one. An alpha channel, if
    present, is dropped.
    """
    batch = images.unsqueeze(0) if images.ndim == 3 else images

    pil_images: List[Image.Image] = []
    for frame in batch:
        array = frame.detach().cpu().numpy()
        if array.dtype not in (np.float32, np.float64):
            array = array.astype(np.float32)
        array = (np.clip(array, 0.0, 1.0) * 255.0).astype(np.uint8)
        if array.shape[-1] == 4:
            array = array[..., :3]  # drop alpha
        pil_images.append(Image.fromarray(array))
    return pil_images
|
|
|
|
|
|
|
|
def _pil_to_comfy_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a ComfyUI tensor: (1, H, W, 3) float32 in [0, 1]."""
    rgb = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(rgb).unsqueeze(0)
|
|
|
|
|
|
|
|
def _prepare_references(
    images: torch.Tensor,
    device: torch.device,
) -> Tuple[List[Image.Image], torch.Tensor]:
    """Extract face crops and ArcFace embeddings from the reference batch.

    Returns the per-reference face images plus a float32 tensor of their
    embeddings stacked along a new leading batch dimension, placed on
    `device`.

    Raises RuntimeError when no face can be extracted from any reference.
    """
    extractor = _get_face_extractor()

    ref_pil: List[Image.Image] = []
    arc_embeddings: List[torch.Tensor] = []
    for pil_image in _comfy_to_pil_batch(images):
        face_crop, embedding = extractor.extract(pil_image)
        if face_crop is None or embedding is None:
            raise RuntimeError("Failed to extract a face embedding from the provided reference image.")
        ref_pil.append(face_crop)
        arc_embeddings.append(torch.tensor(embedding, dtype=torch.float32, device=device))

    return ref_pil, torch.stack(arc_embeddings, dim=0)
|
|
|
|
|
|
|
|
def _get_pipeline(
    model_type: str,
    ipa_path: str,
    clip_path: str,
    t5_path: str,
    flux_path: str,
    siglip_path: str,
    only_lora: bool,
    offload: bool,
    lora_rank: int,
    lora_weight: float,
    additional_lora: Optional[str],
) -> WithAnyonePipeline:
    """Return a WithAnyonePipeline for this configuration, reusing a cached one.

    Pipelines are expensive to construct, so each unique combination of model
    settings (plus the device type) is built once and kept in PIPELINE_CACHE.
    """
    device = _get_device()

    # Every knob that influences construction participates in the cache key.
    cache_key = (
        model_type,
        ipa_path,
        clip_path,
        t5_path,
        flux_path,
        siglip_path,
        only_lora,
        offload,
        lora_rank,
        lora_weight,
        additional_lora,
        device.type,
    )

    cached = PIPELINE_CACHE.get(cache_key)
    if cached is not None:
        # Re-point a previously built pipeline at the current device.
        cached.device = device
        return cached

    pipeline = WithAnyonePipeline(
        model_type=model_type,
        ipa_path=ipa_path,
        device=device,
        offload=offload,
        only_lora=only_lora,
        lora_rank=lora_rank,
        face_extractor=_get_face_extractor(),
        additional_lora_ckpt=additional_lora,
        lora_weight=lora_weight,
        clip_path=clip_path,
        t5_path=t5_path,
        flux_path=flux_path,
        siglip_path=siglip_path,
    )
    PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
|
|
|
|
|
|
|
|
class WithAnyoneNode:
    """
    ComfyUI node that wraps the WithAnyone inference pipeline.

    Takes a text prompt plus a batch of reference face images and produces a
    single generated image along with an info dict describing the run.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's sockets/widgets for the ComfyUI graph editor."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                # One reference image per identity to preserve.
                "ref_images": ("IMAGE",),
            },
            "optional": {
                # JSON (list of boxes or {"bboxes": [...]}) or semicolon-
                # separated "x1,y1,x2,y2" chunks; empty picks a default
                # layout. Coordinates are interpreted on a 512x512 reference
                # canvas and rescaled to the output size (see _scale_bboxes).
                "manual_bboxes": ("STRING", {"default": ""}),
                "width": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}),
                "height": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}),
                "num_steps": ("INT", {"default": 25, "min": 5, "max": 100, "step": 1}),
                "guidance": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 25.0, "step": 0.1}),
                "seed": ("INT", {"default": 1234, "min": 0, "max": 2**32 - 1}),
                "model_type": (["flux-dev", "flux-dev-fp8", "flux-schnell"], {"default": "flux-dev"}),
                # Strength of identity (ArcFace) and SigLIP conditioning.
                "id_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}),
                "siglip_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}),
                "only_lora": ("BOOLEAN", {"default": True}),
                # Offload model weights to save VRAM (slower).
                "offload": ("BOOLEAN", {"default": False}),
                "lora_rank": ("INT", {"default": 64, "min": 1, "max": 128, "step": 1}),
                "lora_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05}),
                # Optional extra LoRA checkpoint path; blank disables it.
                "additional_lora": ("STRING", {"default": ""}),
                # Model locations: HuggingFace repo ids or local paths.
                "ipa_path": ("STRING", {"default": "WithAnyone/WithAnyone"}),
                "clip_path": ("STRING", {"default": "openai/clip-vit-large-patch14"}),
                "t5_path": ("STRING", {"default": "xlabs-ai/xflux_text_encoders"}),
                "flux_path": ("STRING", {"default": "black-forest-labs/FLUX.1-dev"}),
                "siglip_path": ("STRING", {"default": "google/siglip-base-patch16-256-i18n"}),
            },
        }

    # ComfyUI node contract: output sockets, entry-point method, menu category.
    RETURN_TYPES = ("IMAGE", "DICT")
    RETURN_NAMES = ("image", "info")
    FUNCTION = "generate"
    CATEGORY = "withanyone"

    def _create_progress_bar(self, steps: int):
        """Return a ComfyUI ProgressBar for `steps`, or None outside ComfyUI."""
        if ProgressBar is None:
            return None
        return ProgressBar(steps)

    def generate(
        self,
        prompt: str,
        ref_images: torch.Tensor,
        manual_bboxes: str = "",
        width: int = 512,
        height: int = 512,
        num_steps: int = 25,
        guidance: float = 4.0,
        seed: int = 1234,
        model_type: str = "flux-dev",
        id_weight: float = 1.0,
        siglip_weight: float = 1.0,
        only_lora: bool = True,
        offload: bool = False,
        lora_rank: int = 64,
        lora_weight: float = 1.0,
        additional_lora: str = "",
        ipa_path: str = "WithAnyone/WithAnyone",
        clip_path: str = "openai/clip-vit-large-patch14",
        t5_path: str = "xlabs-ai/xflux_text_encoders",
        flux_path: str = "black-forest-labs/FLUX.1-dev",
        siglip_path: str = "google/siglip-base-patch16-256-i18n",
    ):
        """Run the WithAnyone pipeline and return (image tensor, info dict).

        The image is a (1, H, W, 3) float tensor in [0, 1] as expected by
        downstream ComfyUI nodes; `info` records the effective generation
        parameters (including the bboxes actually used).

        Raises RuntimeError when no face can be extracted from a reference
        image and ValueError when `manual_bboxes` is malformed.
        """
        # Blank string means "no extra LoRA".
        additional_lora_ckpt = additional_lora if additional_lora.strip() else None
        device = _get_device()
        progress = self._create_progress_bar(num_steps)

        # Cached per configuration; only the first call with a given setup
        # pays the model-loading cost.
        pipeline = _get_pipeline(
            model_type=model_type,
            ipa_path=ipa_path,
            clip_path=clip_path,
            t5_path=t5_path,
            flux_path=flux_path,
            siglip_path=siglip_path,
            only_lora=only_lora,
            offload=offload,
            lora_rank=lora_rank,
            lora_weight=lora_weight,
            additional_lora=additional_lora_ckpt,
        )

        # Face crops + ArcFace embeddings for each reference identity.
        ref_imgs_pil, arcface_embeddings = _prepare_references(ref_images, device=device)

        # Fall back to a default layout when the user gave no bboxes, then
        # rescale from the 512x512 reference canvas to the output size.
        parsed_bboxes = _parse_manual_bboxes(manual_bboxes)
        if parsed_bboxes is None:
            parsed_bboxes = _select_default_bboxes(len(ref_imgs_pil))
        parsed_bboxes = _scale_bboxes(parsed_bboxes, width, height)

        # NOTE(review): bboxes is wrapped in an outer list — presumably one
        # bbox list per generated image (batch of one here); confirm against
        # WithAnyonePipeline.__call__.
        result_image = pipeline(
            prompt=prompt,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=seed,
            ref_imgs=ref_imgs_pil,
            arcface_embeddings=arcface_embeddings,
            bboxes=[parsed_bboxes],
            id_weight=id_weight,
            siglip_weight=siglip_weight,
        )

        # The pipeline offers no per-step callback here, so the bar is only
        # snapped to completion after the run.
        if progress is not None:
            progress.update_absolute(num_steps, num_steps)

        output_tensor = _pil_to_comfy_image(result_image)
        info = {
            "seed": seed,
            "width": width,
            "height": height,
            "guidance": guidance,
            "num_steps": num_steps,
            "bboxes": parsed_bboxes,
            "model_type": model_type,
        }

        return output_tensor, info
|
|
|
|
|
|
|
|
# Standard ComfyUI registration hooks: map the node's internal id to its
# implementing class and to the name shown in the node-picker UI.
NODE_CLASS_MAPPINGS = {
    "WithAnyoneGenerate": WithAnyoneNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WithAnyoneGenerate": "WithAnyone (Flux)",
}
|
|
|
|
|
|