import os

import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor

from .module.pipeline_fastfit import FastFitPipeline
from .module.utils import resize_and_crop, resize_and_padding
from .parse_utils import DWposeDetector, DensePose, SCHP
from .parse_utils.automasker import (
    cloth_agnostic_mask,
    multi_ref_cloth_agnostic_mask,
)


def parser_output_to_image_tensor(pil_image):
    """
    Converts a PIL image to an IMAGE tensor, handling 'P' (palette) mode images
    by preserving their original index values in the RGB channels.

    Args:
        pil_image (PIL.Image.Image): The input PIL image.

    Returns:
        torch.Tensor: The resulting image tensor in ComfyUI's [B, H, W, C] layout.
    """
    if pil_image.mode == "P":
        # Read the raw palette indices instead of decoding to palette colors,
        # then replicate them across all three channels so downstream nodes
        # can recover the label map from any single channel.
        image_np = np.array(pil_image)
        rgb_np = np.stack([image_np, image_np, image_np], axis=-1)
        pil_image = Image.fromarray(rgb_np, "RGB")
    elif pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    # to_tensor yields [C, H, W] in [0, 1]; permute to HWC and add a batch dim.
    return to_tensor(pil_image).permute(1, 2, 0).unsqueeze(0)
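

# Minimal round-trip sketch (illustrative only): converting a palette map to an
# IMAGE tensor and recovering the integer class labels, assuming `seg` is a PIL
# image in 'P' mode as produced by the SCHP parsers below.
#
#     tensor = parser_output_to_image_tensor(seg)         # [1, H, W, 3] in [0, 1]
#     labels = (tensor[0, :, :, 0] * 255).round().byte()  # back to class indices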


class LoadFastFit:
    display_name = "Load FastFit"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "fastfit_path": ("STRING", {"default": "RedHash/FastFit"}),
                "mixed_precision": (["bf16", "fp32", "fp16"],),
            }
        }

    RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = (
        ("MODEL",),
        ("fastfit_pipeline",),
        "load",
        "FastFit/Loaders",
    )

    def load(self, fastfit_path, mixed_precision):
        # Anything that is not an existing local path is treated as a
        # Hugging Face repo id and downloaded.
        if not os.path.exists(fastfit_path):
            fastfit_path = snapshot_download(repo_id=fastfit_path)
        return (
            FastFitPipeline(
                base_model_path=fastfit_path,
                device="cuda",
                mixed_precision=mixed_precision,
                allow_tf32=True,
            ),
        )


class FastFitPipelineNode:
    display_name = "FastFit Pipeline"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "fastfit_pipeline": ("MODEL",),
                "person": ("IMAGE",),
                "pose": ("IMAGE",),
                "mask": ("MASK",),
                "num_inference_steps": ("INT", {"default": 50}),
                "guidance_scale": ("FLOAT", {"default": 2.5}),
                "generator": (
                    "INT",
                    {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF},
                ),
            },
            "optional": {
                "upper_image": ("IMAGE",),
                "lower_image": ("IMAGE",),
                "dress_image": ("IMAGE",),
                "shoe_image": ("IMAGE",),
                "bag_image": ("IMAGE",),
            },
        }

    RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = (
        ("IMAGE",),
        ("image",),
        "generate",
        "FastFit/Pipelines",
    )

    def generate(
        self,
        fastfit_pipeline,
        person,
        pose,
        mask,
        num_inference_steps,
        guidance_scale,
        generator,
        upper_image=None,
        lower_image=None,
        dress_image=None,
        shoe_image=None,
        bag_image=None,
    ):
        person_image = to_pil_image(person.squeeze(0).permute(2, 0, 1))
        person_image = resize_and_crop(person_image)
        # The pose image is expected to come from UnifiedHumanParserNode, which
        # already applies resize_and_crop, so it is converted as-is.
        pose_image = to_pil_image(pose.squeeze(0).permute(2, 0, 1))
        # MASK tensors are [B, H, W]; to_pil_image handles the [1, H, W] case.
        mask_image = to_pil_image(mask)

        # Collect the optional garment references in a fixed order; the dict
        # keys are the labels the pipeline expects for each garment type.
        ref_images, ref_labels, ref_attention_masks = [], [], []
        ordered_items = {
            "upper": upper_image,
            "lower": lower_image,
            "overall": dress_image,
            "shoe": shoe_image,
            "bag": bag_image,
        }
        for label, img_tensor in ordered_items.items():
            if img_tensor is not None:
                img_pil = to_pil_image(img_tensor.squeeze(0).permute(2, 0, 1))
                target_size = (384, 512)
                img_pil = resize_and_padding(img_pil, target_size)
                ref_images.append(img_pil)
                ref_labels.append(label)
                ref_attention_masks.append(1)

        if not ref_images:
            raise ValueError("At least one reference image must be provided.")

        # `generator` is an integer seed; the actual torch.Generator is built here.
        gen = torch.Generator(device="cuda").manual_seed(generator)
        result_image = fastfit_pipeline(
            person=person_image,
            mask=mask_image,
            ref_images=ref_images,
            ref_labels=ref_labels,
            ref_attention_masks=ref_attention_masks,
            pose=pose_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=gen,
            return_pil=True,
            do_adjust_input_image=True,
        )[0]

        return (to_tensor(result_image).permute(1, 2, 0).unsqueeze(0),)
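

# Illustrative single-garment call (a sketch, not part of the node API; in
# normal use ComfyUI invokes `generate` for you):
#
#     node = FastFitPipelineNode()
#     (image,) = node.generate(
#         fastfit_pipeline, person, pose, mask,
#         num_inference_steps=50, guidance_scale=2.5, generator=42,
#         upper_image=shirt,  # any subset of the optional garments may be set
#     )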


class LoadHumanParsers:
    display_name = "Load Human Parsers (Unified)"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "human_toolkit_path": (
                    "STRING",
                    {"default": "zhengchong/Human-Toolkit"},
                ),
            }
        }

    RETURN_TYPES = ("HUMAN_PARSERS",)
    FUNCTION = "load"
    CATEGORY = "FastFit/Loaders"

    def load(self, human_toolkit_path):
        # As in LoadFastFit, a non-existent local path is treated as a
        # Hugging Face repo id.
        if not os.path.exists(human_toolkit_path):
            human_toolkit_path = snapshot_download(repo_id=human_toolkit_path)
        # DWPose runs on CPU; the heavier DensePose and SCHP models run on CUDA.
        dwpose_detector = DWposeDetector(
            pretrained_model_name_or_path=os.path.join(human_toolkit_path, "DWPose"),
            device="cpu",
        )
        densepose_detector = DensePose(
            model_path=os.path.join(human_toolkit_path, "DensePose"), device="cuda"
        )
        schp_lip_detector = SCHP(
            ckpt_path=os.path.join(human_toolkit_path, "SCHP", "schp-lip.pth"),
            device="cuda",
        )
        schp_atr_detector = SCHP(
            ckpt_path=os.path.join(human_toolkit_path, "SCHP", "schp-atr.pth"),
            device="cuda",
        )

        parsers = (
            dwpose_detector,
            densepose_detector,
            schp_lip_detector,
            schp_atr_detector,
        )
        return (parsers,)


class UnifiedHumanParserNode:
    display_name = "Run Human Parsers (Unified)"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "human_parsers": ("HUMAN_PARSERS",),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("pose_image", "densepose_map", "lip_map", "atr_map")
    FUNCTION = "run_parsers"
    CATEGORY = "FastFit/Detectors"

    def run_parsers(self, human_parsers, image):
        dwpose, densepose, lip, atr = human_parsers

        pil_image = to_pil_image(image.squeeze(0).permute(2, 0, 1))
        processed_pil = resize_and_crop(pil_image)

        # Run all four detectors on the same resized crop so their outputs
        # stay pixel-aligned with the person image used by the pipeline.
        pose_image_pil = dwpose(processed_pil)
        densepose_map_pil = densepose(processed_pil)
        lip_map_pil = lip(processed_pil)
        atr_map_pil = atr(processed_pil)

        pose_tensor = parser_output_to_image_tensor(pose_image_pil)
        densepose_tensor = parser_output_to_image_tensor(densepose_map_pil)
        lip_tensor = parser_output_to_image_tensor(lip_map_pil)
        atr_tensor = parser_output_to_image_tensor(atr_map_pil)

        return (pose_tensor, densepose_tensor, lip_tensor, atr_tensor)
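

# Note: the densepose/lip/atr outputs are palette index maps replicated into
# RGB by parser_output_to_image_tensor; AutoMaskerNode below recovers the
# integer class labels from channel 0, so these IMAGE outputs are label maps
# rather than ordinary pictures.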


class AutoMaskerNode:
    display_name = "Auto Masker"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "densepose_map": ("IMAGE",),
                "lip_map": ("IMAGE",),
                "atr_map": ("IMAGE",),
                "mode": (
                    ["multi_ref", "upper", "lower", "overall", "inner", "outer"],
                    {"default": "multi_ref"},
                ),
                "square_mask": ("BOOLEAN", {"default": False}),
                "horizon_expand": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = (
        ("MASK",),
        ("mask",),
        "generate",
        "FastFit/Masking",
    )

    def _convert_input_tensor(self, tensor: torch.Tensor) -> np.ndarray:
        """Recover a uint8 label map from an IMAGE tensor produced by the parsers."""
        np_array = tensor.squeeze(0).cpu().numpy()
        # Round before casting: the labels went through a /255 -> *255 float
        # round trip, and plain truncation could shift a label down by one.
        if np_array.ndim == 3:
            return np.rint(np_array[:, :, 0] * 255).astype(np.uint8)
        return np.rint(np_array * 255).astype(np.uint8)

    def _convert_output_pil(self, pil_image: Image.Image) -> torch.Tensor:
        """Convert a PIL mask to a ComfyUI MASK tensor of shape [1, H, W]."""
        np_mask = np.array(pil_image.convert("L")).astype(np.float32) / 255.0
        return torch.from_numpy(np_mask).unsqueeze(0)

    def generate(
        self, densepose_map, lip_map, atr_map, mode, square_mask, horizon_expand
    ):
        densepose_arr = self._convert_input_tensor(densepose_map)
        lip_arr = self._convert_input_tensor(lip_map)
        atr_arr = self._convert_input_tensor(atr_map)

        if mode == "multi_ref":
            mask_pil = multi_ref_cloth_agnostic_mask(
                densepose_arr,
                lip_arr,
                atr_arr,
                square_cloth_mask=square_mask,
                horizon_expand=horizon_expand,
            )
        else:
            mask_pil = cloth_agnostic_mask(
                densepose_arr,
                lip_arr,
                atr_arr,
                part=mode,
                square_cloth_mask=square_mask,
            )

        return (self._convert_output_pil(mask_pil),)


class MaskSelectorNode:
    display_name = "Mask Selector"
    description = (
        "Selects the manual mask if provided, otherwise falls back to the "
        "auto-generated mask."
    )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "auto_mask": ("MASK",),
            },
            "optional": {
                "manual_mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("selected_mask",)
    FUNCTION = "select_mask"
    CATEGORY = "FastFit/Masking"

    def select_mask(self, auto_mask, manual_mask=None):
        # A connected but all-zero manual mask is treated as "not provided",
        # so an empty mask input cannot silently disable inpainting.
        if manual_mask is not None and torch.any(manual_mask > 0):
            return (manual_mask,)
        return (auto_mask,)


_export_classes = [
    LoadFastFit,
    FastFitPipelineNode,
    LoadHumanParsers,
    UnifiedHumanParserNode,
    AutoMaskerNode,
    MaskSelectorNode,
]

NODE_CLASS_MAPPINGS = {c.__name__: c for c in _export_classes}
NODE_DISPLAY_NAME_MAPPINGS = {
    c.__name__: getattr(c, "display_name", c.__name__) for c in _export_classes
}
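
# Illustrative wiring (a typical workflow sketch; the actual graph is built in
# the ComfyUI editor):
#   1. LoadHumanParsers -> UnifiedHumanParserNode (person image in)
#   2. densepose/lip/atr maps -> AutoMaskerNode -> MaskSelectorNode -> mask
#   3. LoadFastFit + person image + pose_image + mask (+ garment references)
#      -> FastFitPipelineNode -> final try-on image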