# withanyone_node.py — originally uploaded as "comfyui/withanyone_node.py"
# (repo Cna2/comfyui, commit dde8ccd, "Upload 2 files" by Badnerle1234454).
# The hosting-page header lines were not valid Python and are preserved here
# as a comment instead.
"""
ComfyUI custom node for running the WithAnyone pipeline.
Copy or symlink this file into your ComfyUI `custom_nodes` directory and ensure
the WithAnyone project plus its dependencies are available in the Python path.
"""
from __future__ import annotations
import json
import random
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
import torch
from PIL import Image
try:
from comfy import model_management
from comfy.utils import ProgressBar
except ImportError: # pragma: no cover - only executed outside ComfyUI
model_management = None # type: ignore
ProgressBar = None # type: ignore
from withanyone.flux.pipeline import WithAnyonePipeline
from util import FaceExtractor
# Fallback face placements, expressed in a 512x512 reference canvas and later
# rescaled by _scale_bboxes. Format per box: [x1, y1, x2, y2].
DEFAULT_SINGLE_BBOXES: List[List[int]] = [
    [150, 100, 250, 200],
    [100, 100, 200, 200],
    [200, 100, 300, 200],
    [250, 100, 350, 200],
    [300, 100, 400, 200],
]
# Candidate two-face layouts used when two or more identities are requested.
DEFAULT_DOUBLE_BBOXES: List[List[List[int]]] = [
    [[100, 100, 200, 200], [300, 100, 400, 200]],
    [[150, 100, 250, 200], [300, 100, 400, 200]],
]
# Loaded pipelines keyed by their full configuration tuple (see _get_pipeline).
# Never evicted, so each distinct configuration stays resident once loaded.
PIPELINE_CACHE: Dict[Tuple, WithAnyonePipeline] = {}
# Process-wide FaceExtractor singleton, created lazily by _get_face_extractor().
FACE_EXTRACTOR: Optional[FaceExtractor] = None
def _get_device() -> torch.device:
    """Pick the inference device, preferring ComfyUI's model management."""
    if model_management is None:
        # Standalone fallback when running outside ComfyUI.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model_management.get_torch_device()
def _get_face_extractor() -> FaceExtractor:
    """Return the shared FaceExtractor, constructing it on first use."""
    global FACE_EXTRACTOR
    extractor = FACE_EXTRACTOR
    if extractor is None:
        extractor = FaceExtractor()
        FACE_EXTRACTOR = extractor
    return extractor
def _select_default_bboxes(identity_count: int) -> List[List[int]]:
    """Pick fallback face bounding boxes when the user supplies none.

    Returns a list of ``[x1, y1, x2, y2]`` boxes in the 512x512 reference
    space: a random two-face layout when ``identity_count`` is two or more,
    otherwise a list containing a single random one-face box.

    Bug fixed: the original used ``[*DEFAULT_SINGLE_BBOXES[i]]``, which
    unpacked the single box into a flat ``[x1, y1, x2, y2]`` instead of
    ``[[x1, y1, x2, y2]]`` and broke the 4-tuple unpacking in
    ``_scale_bboxes``.

    NOTE(review): for more than two identities only two boxes are returned,
    matching the original behavior — confirm the pipeline tolerates the
    count mismatch.
    """
    if identity_count >= 2:
        # Copy each box so callers can mutate without touching the defaults.
        return [list(box) for box in random.choice(DEFAULT_DOUBLE_BBOXES)]
    return [list(random.choice(DEFAULT_SINGLE_BBOXES))]
def _parse_manual_bboxes(spec: str) -> Optional[List[List[int]]]:
if not spec or not spec.strip():
return None
spec = spec.strip()
try:
parsed = json.loads(spec)
except json.JSONDecodeError:
parsed = []
for chunk in spec.split(";"):
chunk = chunk.strip()
if not chunk:
continue
values = [float(value.strip()) for value in chunk.split(",")]
if len(values) != 4:
raise ValueError(f"Expected 4 values per bbox, got {len(values)}: {chunk}")
parsed.append(values)
if isinstance(parsed, dict) and "bboxes" in parsed:
parsed = parsed["bboxes"]
if not isinstance(parsed, Sequence):
raise ValueError("Bounding box specification must be a list or dictionary with 'bboxes'.")
cleaned: List[List[int]] = []
for entry in parsed:
if isinstance(entry, str):
coords = [float(value.strip()) for value in entry.split(",")]
elif isinstance(entry, Iterable):
coords = [float(value) for value in entry]
else:
raise ValueError(f"Unsupported bbox entry type: {type(entry)}")
if len(coords) != 4:
raise ValueError(f"Each bbox needs four coordinates, received {coords}")
cleaned.append([int(round(coord)) for coord in coords])
return cleaned
def _scale_bboxes(bboxes: List[List[int]], width: int, height: int, reference: int = 512) -> List[List[int]]:
if width == reference and height == reference:
return bboxes
sx = width / float(reference)
sy = height / float(reference)
scaled = []
for x1, y1, x2, y2 in bboxes:
scaled.append(
[
int(round(x1 * sx)),
int(round(y1 * sy)),
int(round(x2 * sx)),
int(round(y2 * sy)),
]
)
return scaled
def _comfy_to_pil_batch(images: torch.Tensor) -> List[Image.Image]:
    """Convert a ComfyUI image tensor into a list of PIL images.

    Accepts a batched (B, H, W, C) tensor or a single (H, W, C) frame with
    float values in [0, 1]; an alpha channel, when present, is dropped.
    """
    batch = images.unsqueeze(0) if images.ndim == 3 else images
    converted: List[Image.Image] = []
    for frame in batch:
        pixels = frame.detach().cpu().numpy()
        if pixels.dtype not in (np.float32, np.float64):
            pixels = pixels.astype(np.float32)
        pixels = (np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)
        if pixels.shape[-1] == 4:
            # Keep only the RGB planes.
            pixels = pixels[..., :3]
        converted.append(Image.fromarray(pixels))
    return converted
def _pil_to_comfy_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a (1, H, W, 3) float tensor in [0, 1]."""
    rgb = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    # Prepend the batch dimension ComfyUI expects.
    return torch.from_numpy(rgb).unsqueeze(0)
def _prepare_references(
    images: torch.Tensor,
    device: torch.device,
) -> Tuple[List[Image.Image], torch.Tensor]:
    """Extract face crops and ArcFace embeddings from reference images.

    Returns the cropped faces as PIL images together with a stacked (N, D)
    float32 embedding tensor placed on ``device``.

    Raises:
        RuntimeError: when no face can be extracted from an input image.
    """
    extractor = _get_face_extractor()
    crops: List[Image.Image] = []
    embeddings: List[torch.Tensor] = []
    for reference in _comfy_to_pil_batch(images):
        crop, embedding = extractor.extract(reference)
        if crop is None or embedding is None:
            raise RuntimeError("Failed to extract a face embedding from the provided reference image.")
        crops.append(crop)
        embeddings.append(torch.tensor(embedding, dtype=torch.float32, device=device))
    return crops, torch.stack(embeddings, dim=0)
def _get_pipeline(
    model_type: str,
    ipa_path: str,
    clip_path: str,
    t5_path: str,
    flux_path: str,
    siglip_path: str,
    only_lora: bool,
    offload: bool,
    lora_rank: int,
    lora_weight: float,
    additional_lora: Optional[str],
) -> WithAnyonePipeline:
    """Fetch a cached WithAnyonePipeline or build one for this configuration.

    The cache key covers every constructor argument plus the device type, so
    changing any model path or LoRA setting triggers a fresh load. A cache
    hit is re-pointed at the current device before being returned.
    """
    device = _get_device()
    cache_key = (
        model_type,
        ipa_path,
        clip_path,
        t5_path,
        flux_path,
        siglip_path,
        only_lora,
        offload,
        lora_rank,
        lora_weight,
        additional_lora,
        device.type,
    )
    cached = PIPELINE_CACHE.get(cache_key)
    if cached is not None:
        # Reuse loaded weights; just refresh the device handle.
        cached.device = device
        return cached
    pipeline = WithAnyonePipeline(
        model_type=model_type,
        ipa_path=ipa_path,
        device=device,
        offload=offload,
        only_lora=only_lora,
        lora_rank=lora_rank,
        face_extractor=_get_face_extractor(),
        additional_lora_ckpt=additional_lora,
        lora_weight=lora_weight,
        clip_path=clip_path,
        t5_path=t5_path,
        flux_path=flux_path,
        siglip_path=siglip_path,
    )
    PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
class WithAnyoneNode:
    """
    ComfyUI node that wraps the WithAnyone inference pipeline.

    Takes a text prompt plus one or more reference face images, extracts
    ArcFace embeddings, resolves face bounding boxes (manual or default),
    and renders an image through a cached WithAnyonePipeline.
    """

    @classmethod
    def INPUT_TYPES(cls):  # noqa: N802 - ComfyUI API
        """Declare the node's input sockets for the ComfyUI graph editor."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "ref_images": ("IMAGE",),
            },
            "optional": {
                # Either JSON ([[x1,y1,x2,y2], ...]) or "x1,y1,x2,y2;..." text;
                # see _parse_manual_bboxes.
                "manual_bboxes": ("STRING", {"default": ""}),
                "width": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}),
                "height": ("INT", {"default": 512, "min": 256, "max": 1024, "step": 16}),
                "num_steps": ("INT", {"default": 25, "min": 5, "max": 100, "step": 1}),
                "guidance": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 25.0, "step": 0.1}),
                "seed": ("INT", {"default": 1234, "min": 0, "max": 2**32 - 1}),
                "model_type": (["flux-dev", "flux-dev-fp8", "flux-schnell"], {"default": "flux-dev"}),
                "id_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}),
                "siglip_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.05}),
                "only_lora": ("BOOLEAN", {"default": True}),
                "offload": ("BOOLEAN", {"default": False}),
                "lora_rank": ("INT", {"default": 64, "min": 1, "max": 128, "step": 1}),
                "lora_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05}),
                "additional_lora": ("STRING", {"default": ""}),
                # Hugging Face repo ids (or local paths) for the model weights.
                "ipa_path": ("STRING", {"default": "WithAnyone/WithAnyone"}),
                "clip_path": ("STRING", {"default": "openai/clip-vit-large-patch14"}),
                "t5_path": ("STRING", {"default": "xlabs-ai/xflux_text_encoders"}),
                "flux_path": ("STRING", {"default": "black-forest-labs/FLUX.1-dev"}),
                "siglip_path": ("STRING", {"default": "google/siglip-base-patch16-256-i18n"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "DICT")
    RETURN_NAMES = ("image", "info")
    FUNCTION = "generate"
    CATEGORY = "withanyone"

    def _create_progress_bar(self, steps: int):
        # Returns None outside ComfyUI, where ProgressBar is unavailable.
        if ProgressBar is None:
            return None
        return ProgressBar(steps)

    def generate(  # noqa: C901 - ComfyUI entry points are typically long
        self,
        prompt: str,
        ref_images: torch.Tensor,
        manual_bboxes: str = "",
        width: int = 512,
        height: int = 512,
        num_steps: int = 25,
        guidance: float = 4.0,
        seed: int = 1234,
        model_type: str = "flux-dev",
        id_weight: float = 1.0,
        siglip_weight: float = 1.0,
        only_lora: bool = True,
        offload: bool = False,
        lora_rank: int = 64,
        lora_weight: float = 1.0,
        additional_lora: str = "",
        ipa_path: str = "WithAnyone/WithAnyone",
        clip_path: str = "openai/clip-vit-large-patch14",
        t5_path: str = "xlabs-ai/xflux_text_encoders",
        flux_path: str = "black-forest-labs/FLUX.1-dev",
        siglip_path: str = "google/siglip-base-patch16-256-i18n",
    ):
        """Run the WithAnyone pipeline once and return ``(image, info)``.

        ``image`` is the pipeline output converted to a (1, H, W, 3) float
        tensor in [0, 1]; ``info`` is a dict echoing the generation settings
        and the bounding boxes actually used.

        Raises:
            RuntimeError: if face extraction fails for a reference image.
            ValueError: if ``manual_bboxes`` is malformed.
        """
        # Blank string means "no extra LoRA checkpoint".
        additional_lora_ckpt = additional_lora if additional_lora.strip() else None
        device = _get_device()
        progress = self._create_progress_bar(num_steps)
        pipeline = _get_pipeline(
            model_type=model_type,
            ipa_path=ipa_path,
            clip_path=clip_path,
            t5_path=t5_path,
            flux_path=flux_path,
            siglip_path=siglip_path,
            only_lora=only_lora,
            offload=offload,
            lora_rank=lora_rank,
            lora_weight=lora_weight,
            additional_lora=additional_lora_ckpt,
        )
        ref_imgs_pil, arcface_embeddings = _prepare_references(ref_images, device=device)
        parsed_bboxes = _parse_manual_bboxes(manual_bboxes)
        if parsed_bboxes is None:
            # No manual boxes: fall back to canned layouts for 1 or 2 faces.
            parsed_bboxes = _select_default_bboxes(len(ref_imgs_pil))
        # Boxes (manual or default) are interpreted in 512x512 reference
        # space and scaled to the requested output size.
        parsed_bboxes = _scale_bboxes(parsed_bboxes, width, height)
        result_image = pipeline(
            prompt=prompt,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=seed,
            ref_imgs=ref_imgs_pil,
            arcface_embeddings=arcface_embeddings,
            # Wrapped in a list — presumably one bbox list per batch item;
            # TODO confirm against WithAnyonePipeline.__call__.
            bboxes=[parsed_bboxes],
            id_weight=id_weight,
            siglip_weight=siglip_weight,
        )
        # NOTE(review): the progress bar is never handed to the pipeline, so
        # it jumps straight to 100% here rather than tracking each step.
        if progress is not None:
            progress.update_absolute(num_steps, num_steps)
        output_tensor = _pil_to_comfy_image(result_image)
        info = {
            "seed": seed,
            "width": width,
            "height": height,
            "guidance": guidance,
            "num_steps": num_steps,
            "bboxes": parsed_bboxes,
            "model_type": model_type,
        }
        return output_tensor, info
# ComfyUI registration: node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "WithAnyoneGenerate": WithAnyoneNode,
}
# Human-readable titles shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "WithAnyoneGenerate": "WithAnyone (Flux)",
}