"""Automatic face/body enhancement processor for LightDiffusion-Next.
This processor uses detection models to identify faces and bodies,
then applies targeted inpainting/enhancement to those regions.
"""
import logging
import random
import re
import time
from typing import TYPE_CHECKING, Any, Optional, Callable
import numpy as np
import torch
if TYPE_CHECKING:
from src.Core.PipelineContext import PipelineContext
from src.Core.AbstractModel import AbstractModel
class Adetailer:
"""Automatic face and body detailing processor.
Uses YOLO detection and SAM segmentation to identify regions
of interest, then applies targeted inpainting enhancement.
"""
# Default settings
DEFAULT_GUIDE_SIZE = 512
DEFAULT_MAX_SIZE = 768
DEFAULT_STEPS = 20
DEFAULT_CFG = 6.5
DEFAULT_DENOISE = 0.5
DEFAULT_SCHEDULER = "karras"
    DEFAULT_POSITIVE_PROMPT = "royal, detailed, magnificent, beautiful, seducing"
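    # Note: when no positive conditioning is passed in, apply() falls back to re-encoding
    # the user's prompt from the pipeline context rather than this default prompt.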
@classmethod
def _runtime_profile(cls, ctx: "PipelineContext", model: "AbstractModel") -> dict[str, Any]:
is_flux = getattr(model.capabilities, "is_flux", False)
is_flux2 = getattr(model.capabilities, "is_flux2", False)
is_sdxl = getattr(model.capabilities, "uses_dual_clip", False)
profile = {
"is_flux": is_flux,
"is_flux2": is_flux2,
"is_sdxl": is_sdxl,
"guide_size": cls.DEFAULT_GUIDE_SIZE,
"max_size": cls.DEFAULT_MAX_SIZE,
"steps": cls.DEFAULT_STEPS,
"cfg": cls.DEFAULT_CFG,
"denoise": cls.DEFAULT_DENOISE,
"scheduler": cls.DEFAULT_SCHEDULER,
"body_crop_factor": 2.0,
"face_crop_factor": 2.0,
}
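        # Model-family overrides: SDXL gets a lighter pass (fewer steps, lower denoise,
        # tighter crop factors, the user's scheduler); Flux and Flux2 run at CFG 1.0.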
if is_sdxl:
profile.update(
guide_size=512,
max_size=768,
steps=8,
cfg=cls.DEFAULT_CFG,
denoise=0.35,
scheduler=ctx.sampling.scheduler,
body_crop_factor=1.4,
face_crop_factor=1.6,
)
elif is_flux2:
profile.update(steps=6, cfg=1.0)
elif is_flux:
profile.update(steps=20, cfg=1.0)
return profile
@classmethod
def apply(
cls,
image: torch.Tensor,
ctx: "PipelineContext",
model: "AbstractModel",
positive: Any = None,
negative: Any = None,
callback: Optional[Callable] = None,
) -> tuple[torch.Tensor, list[dict]]:
"""Apply automatic face and body enhancement.
Args:
image: Input image tensor [B, H, W, C] or [H, W, C]
ctx: Pipeline context with configuration
model: The loaded model instance
            positive: Optional positive conditioning (re-encoded from the user's prompt if not provided)
negative: Negative conditioning from original generation
callback: Optional callback for live previews
Returns:
Tuple of (enhanced_image, list_of_saved_intermediate_images_metadata)
"""
logger = logging.getLogger(__name__)
saved_images = []
try:
# Ensure image has batch dimension
if image.dim() == 3:
image = image.unsqueeze(0)
# Import required modules
from src.AutoDetailer import SAM, SEGS, ADetailer, bbox
from src.clip import Clip
from src.FileManaging import ImageSaver
from src.AutoHDR import ahdr
# Load detection and segmentation models
samloader = SAM.SAMLoader()
sam_model = samloader.load_model(
model_name="sam_vit_b_01ec64.pth",
device_mode="AUTO"
)[0]
# Load YOLO detector for person/body
detector_provider = bbox.UltralyticsDetectorProvider()
body_detector = detector_provider.doit(model_name="person_yolov8m-seg.pt")[0]
# Use original positive conditioning if provided (preserves SDXL pooled_output
# and semantic context). Otherwise, re-encode from user's actual prompt.
cliptextencode = Clip.CLIPTextEncode()
if positive is not None:
adetailer_positive = positive
else:
# Fall back to user's prompt from context for semantic consistency
prompt_text = ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt)
adetailer_positive = cliptextencode.encode(
text=prompt_text,
clip=model.clip,
)[0]
# Initialize processors
bbox_detector = bbox.BboxDetectorForEach()
sam_detector = SAM.SAMDetectorCombined()
segs_mask = SEGS.SegsBitwiseAndMask()
detailer = ADetailer.DetailerForEachTest()
saveimage = ImageSaver.SaveImage()
hdr = ahdr.HDREffects()
profile = cls._runtime_profile(ctx, model)
# ===== BODY PASS =====
# Detect body regions
body_segs = bbox_detector.doit(
threshold=0.5,
dilation=10,
crop_factor=profile["body_crop_factor"],
drop_size=10,
labels="all",
bbox_detector=body_detector,
image=image,
)
# Apply SAM for precise segmentation
sam_result = sam_detector.doit(
detection_hint="center-1",
dilation=0,
threshold=0.93,
bbox_expansion=0,
mask_hint_threshold=0.7,
mask_hint_use_negative="False",
sam_model=sam_model,
segs=body_segs,
image=image,
)
if sam_result is None:
logger.info("Adetailer: No body regions detected")
                return (image[0] if image.shape[0] == 1 else image), saved_images
# Combine segmentation masks
combined_segs = segs_mask.doit(
segs=body_segs,
mask=sam_result[0],
)
# Apply body enhancement
body_seed = random.randint(1, 2**63 - 1)
body_start = time.perf_counter()
body_result = detailer.doit(
guide_size=profile["guide_size"],
guide_size_for=False,
max_size=profile["max_size"],
seed=body_seed,
steps=profile["steps"],
cfg=profile["cfg"],
sampler_name=ctx.sampling.sampler,
scheduler=profile["scheduler"],
denoise=profile["denoise"],
feather=5,
noise_mask=True,
force_inpaint=True,
wildcard="",
cycle=1,
inpaint_model=False,
noise_mask_feather=0,
image=image,
segs=combined_segs[0],
model=model.model,
clip=model.clip,
vae=model.vae,
positive=adetailer_positive,
negative=negative,
pipeline=True,
callback=callback,
)
logger.info(
"Adetailer body pass: guide=%s max=%s steps=%s scheduler=%s denoise=%s elapsed=%.2fs",
profile["guide_size"],
profile["max_size"],
profile["steps"],
profile["scheduler"],
profile["denoise"],
time.perf_counter() - body_start,
)
# Extract enhanced body image
body_image = body_result[0]
body_seed_str = cls._extract_seed(body_result, body_seed)
# Apply HDR if enabled
if ctx.generation.autohdr:
try:
hdr_result = hdr.apply_hdr2(body_image)
body_image = hdr_result[0] if isinstance(hdr_result, (tuple, list)) else hdr_result
except Exception:
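                    # If HDR fails, keep the non-HDR body image and continue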
pass
# Save body-enhanced image
body_meta = cls._build_metadata(ctx, body_seed_str, "body")
# Update meta with actual steps/cfg used
body_meta["steps"] = str(profile["steps"])
body_meta["cfg"] = str(profile["cfg"])
saved_body = saveimage.save_images(
filename_prefix="LD-body",
images=body_image,
prompt=ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt),
extra_pnginfo=body_meta,
)
saved_images.append(saved_body)
# ===== FACE PASS =====
# Check for interrupt before starting the next pass
from src.user import app_instance
app = getattr(app_instance, "app", None)
if app and getattr(app, "interrupt_flag", False):
logger.info("Adetailer: Interrupt requested, skipping Face pass")
                return (body_image[0] if body_image.shape[0] == 1 else body_image), saved_images
# Load face detector
face_detector = detector_provider.doit(model_name="face_yolov9c.pt")[0]
# Detect face regions on the body-enhanced image
face_segs = bbox_detector.doit(
threshold=0.5,
dilation=10,
crop_factor=profile["face_crop_factor"],
drop_size=10,
labels="all",
bbox_detector=face_detector,
image=body_image,
)
# Apply SAM for face segmentation
face_sam_result = sam_detector.doit(
detection_hint="center-1",
dilation=0,
threshold=0.93,
bbox_expansion=0,
mask_hint_threshold=0.7,
mask_hint_use_negative="False",
sam_model=sam_model,
segs=face_segs,
image=body_image,
)
if face_sam_result is None:
logger.info("Adetailer: No face regions detected")
                return (body_image[0] if body_image.shape[0] == 1 else body_image), saved_images
# Combine face segmentation masks
face_combined_segs = segs_mask.doit(
segs=face_segs,
mask=face_sam_result[0],
)
# Apply face enhancement
face_seed = random.randint(1, 2**63 - 1)
face_start = time.perf_counter()
face_result = detailer.doit(
guide_size=profile["guide_size"],
guide_size_for=False,
max_size=profile["max_size"],
seed=face_seed,
steps=profile["steps"],
cfg=profile["cfg"],
sampler_name=ctx.sampling.sampler,
scheduler=profile["scheduler"],
denoise=profile["denoise"],
feather=5,
noise_mask=True,
force_inpaint=True,
wildcard="",
cycle=1,
inpaint_model=False,
noise_mask_feather=0,
image=body_image,
segs=face_combined_segs[0],
model=model.model,
clip=model.clip,
vae=model.vae,
positive=adetailer_positive,
negative=negative,
pipeline=True,
callback=callback,
)
logger.info(
"Adetailer face pass: guide=%s max=%s steps=%s scheduler=%s denoise=%s elapsed=%.2fs",
profile["guide_size"],
profile["max_size"],
profile["steps"],
profile["scheduler"],
profile["denoise"],
time.perf_counter() - face_start,
)
# Extract final enhanced image
final_image = face_result[0]
face_seed_str = cls._extract_seed(face_result, face_seed)
# Apply HDR if enabled
if ctx.generation.autohdr:
try:
hdr_result = hdr.apply_hdr2(final_image)
final_image = hdr_result[0] if isinstance(hdr_result, (tuple, list)) else hdr_result
except Exception:
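                    # If HDR fails, keep the non-HDR final image and continue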
pass
# Save face-enhanced (final) image
face_meta = cls._build_metadata(ctx, face_seed_str, "head")
face_meta["steps"] = str(profile["steps"])
face_meta["cfg"] = str(profile["cfg"])
saved_face = saveimage.save_images(
filename_prefix="LD-head",
images=final_image,
prompt=ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt),
extra_pnginfo=face_meta,
)
saved_images.append(saved_face)
logger.info("Adetailer: completed body and face enhancement")
# Return final image (remove batch dim if it was added)
            return (final_image[0] if final_image.shape[0] == 1 else final_image), saved_images
except Exception as e:
            logger.exception("Adetailer failed: %s", e)
# Return original image on failure
            return (image[0] if image.dim() == 4 and image.shape[0] == 1 else image), saved_images
@classmethod
def _extract_seed(cls, result: Any, fallback_seed: int) -> str:
"""Extract seed from detailer result safely.
Args:
result: Result from detailer (may be tuple with seed)
fallback_seed: Seed to use if extraction fails
Returns:
String representation of the seed
"""
try:
if isinstance(result, (list, tuple)) and len(result) > 1:
candidate = result[1]
if isinstance(candidate, int):
return str(candidate)
                if isinstance(candidate, float) and candidate.is_integer():
return str(int(candidate))
if isinstance(candidate, str):
s = candidate.strip()
if re.fullmatch(r"-?\d+", s):
return s
m = re.search(r"\d{4,}", s)
if m:
return m.group(0)
if isinstance(candidate, np.ndarray) and candidate.size == 1:
return str(int(candidate.item()))
if isinstance(candidate, torch.Tensor) and candidate.numel() == 1:
return str(int(candidate.item()))
except Exception:
pass
return str(fallback_seed)
@classmethod
def _build_metadata(
cls,
ctx: "PipelineContext",
seed: str,
pass_type: str,
) -> dict:
"""Build metadata dictionary for saved images.
Args:
ctx: Pipeline context
seed: Seed used for this pass
pass_type: Type of enhancement pass ('body' or 'head')
Returns:
Metadata dictionary
"""
return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"prompt": ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt),
"negative_prompt": ctx.negative_prompt if isinstance(ctx.negative_prompt, str) else str(ctx.negative_prompt),
"seed": seed,
"sampler": ctx.sampling.sampler,
"steps": str(cls.DEFAULT_STEPS),
"cfg": str(cls.DEFAULT_CFG),
"scheduler": cls.DEFAULT_SCHEDULER,
"denoise": str(cls.DEFAULT_DENOISE),
"width": str(ctx.generation.width),
"height": str(ctx.generation.height),
"batch_size": str(1),
"adetailer": "True",
"adetailer_pass": pass_type,
}
@classmethod
def is_enabled(cls, ctx: "PipelineContext") -> bool:
"""Check if Adetailer should be applied based on context.
Args:
ctx: Pipeline context
Returns:
True if Adetailer should be applied
"""
return ctx.features.adetailer