LandmarkDiff / landmarkdiff /inference.py
dreamlessx's picture
Upload landmarkdiff/inference.py with huggingface_hub
996d0fe verified
raw
history blame
19 kB
"""Inference pipeline for surgical outcome prediction.
Modes:
1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (HF auth + GPU)
2. ControlNet + IP-Adapter: ControlNet w/ identity preservation
3. Img2Img: SD1.5 img2img with mask compositing (MPS ok, no auth)
4. TPS-only: geometric warp, no diffusion, instant
Works on MPS (Apple Silicon), CUDA, and CPU.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import torch
from PIL import Image
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
from landmarkdiff.conditioning import generate_conditioning
from landmarkdiff.manipulation import apply_procedure_preset
from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
from landmarkdiff.synthetic.tps_warp import warp_image_tps
def get_device() -> torch.device:
    """Pick the torch device to run on.

    Apple MPS is probed first, then CUDA; CPU is the fallback when neither
    backend reports itself as available.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
def numpy_to_pil(arr: np.ndarray) -> Image.Image:
    """Convert a numpy image into a PIL image.

    A 2-D array becomes a single-channel ("L") image. Any other array has its
    last axis reversed first, which turns OpenCV-style BGR into the RGB order
    PIL expects.
    """
    if arr.ndim == 2:
        return Image.fromarray(arr, mode="L")
    channels_reversed = arr[:, :, ::-1]
    return Image.fromarray(channels_reversed)
def pil_to_numpy(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a numpy array.

    Three-channel images have their channel axis reversed (RGB -> BGR) and are
    returned as a fresh contiguous copy. Everything else (grayscale, 4-channel)
    is returned exactly as ``np.array`` produced it.
    """
    pixels = np.array(img)
    is_three_channel = pixels.ndim == 3 and pixels.shape[2] == 3
    if not is_three_channel:
        return pixels
    return pixels[:, :, ::-1].copy()
# Every procedure prompt shares the same lead-in and photographic-quality tail;
# only the anatomical focus in the middle differs per procedure.
_PROMPT_LEAD = "clinical photograph, patient face, "
_PROMPT_TAIL = (
    "realistic skin pores and texture, sharp focus, studio lighting, "
    "DSLR quality, natural skin color"
)
_PROCEDURE_FOCUS = {
    "rhinoplasty": "natural refined nose, smooth nasal bridge",
    "blepharoplasty": "natural eyelids, smooth periorbital area",
    "rhytidectomy": "defined jawline, smooth facial contour",
    "orthognathic": "balanced jaw and chin proportions",
}

# Positive text prompt per supported procedure name.
PROCEDURE_PROMPTS: dict[str, str] = {
    name: f"{_PROMPT_LEAD}{focus}, {_PROMPT_TAIL}"
    for name, focus in _PROCEDURE_FOCUS.items()
}
# Negative prompt shared by all diffusion modes: a comma-separated list of
# artefacts and non-photographic styles to steer away from.
NEGATIVE_PROMPT = ", ".join(
    (
        "painting", "drawing", "illustration", "cartoon", "anime", "render", "3d", "cgi",
        "blurry", "distorted", "deformed", "disfigured", "bad anatomy", "bad proportions",
        "extra limbs", "mutated", "poorly drawn face", "ugly", "low quality", "low resolution",
        "watermark", "text", "signature", "duplicate", "artifact", "noise", "overexposed",
        "plastic skin", "waxy", "smooth skin", "airbrushed", "oversaturated",
    )
)
def mask_composite(
    warped: np.ndarray,
    original: np.ndarray,
    mask: np.ndarray,
    use_laplacian: bool = True,
) -> np.ndarray:
    """Blend the warped region into the original image.

    The warped image is first colour-matched to the original inside the mask
    (LAB statistics transfer), then blended with a Laplacian pyramid when
    available, otherwise with a plain per-pixel alpha blend.

    Args:
        warped: Modified image to paste in (BGR, same size as ``original``).
        original: Untouched reference image.
        mask: Blend weights. A mask whose maximum exceeds 1.0 is treated as
            0-255 and rescaled to 0-1.
        use_laplacian: Try ``landmarkdiff.postprocess.laplacian_pyramid_blend``
            first; set to False to go straight to the alpha blend.

    Returns:
        The composited image. The alpha-blend path returns ``uint8``.
    """
    mask_f = mask.astype(np.float32)
    if mask_f.max() > 1.0:
        mask_f = mask_f / 255.0

    # Match color of warped region to original skin tone in LAB space
    corrected = _match_skin_tone(warped, original, mask_f)

    if use_laplacian:
        try:
            from landmarkdiff.postprocess import laplacian_pyramid_blend

            return laplacian_pyramid_blend(corrected, original, mask_f)
        except Exception as exc:
            # Best-effort: keep the fallback, but do not hide why the
            # higher-quality blend was skipped.
            import logging

            logging.getLogger(__name__).warning(
                "Laplacian pyramid blend failed (%r); falling back to alpha blend",
                exc,
            )

    # Fallback: simple alpha blend
    mask_3ch = mask_to_3channel(mask_f)
    blended = (
        corrected.astype(np.float32) * mask_3ch
        + original.astype(np.float32) * (1.0 - mask_3ch)
    )
    # Clip before the cast so out-of-range values cannot wrap around in uint8.
    return np.clip(blended, 0, 255).astype(np.uint8)
def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
"""LAB-space color transfer so warped region matches original skin tone."""
mask_bool = mask > 0.3
if not np.any(mask_bool):
return source
src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
# match per-channel stats in masked region
for ch in range(3):
src_vals = src_lab[:, :, ch][mask_bool]
tgt_vals = tgt_lab[:, :, ch][mask_bool]
src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
# shift+scale to match target distribution
src_lab[:, :, ch] = np.where(
mask_bool,
(src_lab[:, :, ch] - src_mean) * (tgt_std / src_std) + tgt_mean,
src_lab[:, :, ch],
)
src_lab = np.clip(src_lab, 0, 255)
return cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
def poisson_blend(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend ``source`` into ``target`` under ``mask``.

    Despite the name, no Poisson solve happens here: the call is forwarded to
    ``mask_composite`` with its default settings.
    """
    blended = mask_composite(warped=source, original=target, mask=mask)
    return blended
class LandmarkDiffPipeline:
    """Image -> landmarks -> manipulate -> generate.

    Supported ``mode`` values:

    * ``"tps"``: geometric warp only, no model is loaded.
    * ``"controlnet"``: ControlNet-conditioned Stable Diffusion.
    * ``"controlnet_ip"``: ControlNet plus IP-Adapter (falls back to plain
      ControlNet if the adapter cannot be loaded).
    * anything else (default ``"img2img"``): Stable Diffusion img2img on the
      TPS-warped image.

    Call :meth:`load` once before :meth:`generate`.
    """

    # Default IP-Adapter model for SD1.5 face identity
    IP_ADAPTER_REPO = "h94/IP-Adapter"
    IP_ADAPTER_SUBFOLDER = "models"
    IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus-face_sd15.bin"
    IP_ADAPTER_SCALE_DEFAULT = 0.6
    # Base checkpoint used by every diffusion mode unless the caller overrides it
    DEFAULT_BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

    def __init__(
        self,
        mode: str = "img2img",
        controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
        base_model_id: str | None = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        ip_adapter_scale: float = IP_ADAPTER_SCALE_DEFAULT,
        clinical_flags: Optional["ClinicalFlags"] = None,
    ):
        """Configure the pipeline; no weights are loaded here.

        Args:
            mode: Generation mode, see the class docstring.
            controlnet_id: Hub id of the ControlNet checkpoint.
            base_model_id: Hub id of the base model; defaults to
                ``DEFAULT_BASE_MODEL_ID`` when falsy.
            device: Target device (a ``torch.device`` or anything
                ``torch.device`` accepts, e.g. ``"cpu"``). Auto-detected via
                ``get_device()`` when None.
            dtype: Weight dtype. Ignored on MPS, which always uses float32.
                When None: float16 on CUDA, float32 elsewhere.
            ip_adapter_scale: Scale passed to ``set_ip_adapter_scale``.
            clinical_flags: Default flags forwarded to the manipulation and
                masking helpers; can be overridden per ``generate`` call.
        """
        self.mode = mode
        # Normalise so that `.type` below works even if a string was passed.
        self.device = torch.device(device) if device is not None else get_device()
        self.ip_adapter_scale = ip_adapter_scale
        self.clinical_flags = clinical_flags

        if self.device.type == "mps":
            self.dtype = torch.float32
        elif dtype is not None:
            self.dtype = dtype
        else:
            self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        # All modes share the same default base checkpoint.
        self.base_model_id = base_model_id or self.DEFAULT_BASE_MODEL_ID
        self.controlnet_id = controlnet_id
        self._pipe = None
        self._ip_adapter_loaded = False

    def load(self) -> None:
        """Load the model(s) required by ``self.mode`` (no-op for ``"tps"``)."""
        if self.mode == "tps":
            print("TPS mode - no model to load")
            return
        if self.mode in ("controlnet", "controlnet_ip"):
            self._load_controlnet()
            if self.mode == "controlnet_ip":
                self._load_ip_adapter()
        else:
            self._load_img2img()

    def _load_controlnet(self) -> None:
        """Build the ControlNet + base-model pipeline and place it on the device."""
        from diffusers import (
            ControlNetModel,
            StableDiffusionControlNetPipeline,
            DPMSolverMultistepScheduler,
        )

        print(f"Loading ControlNet from {self.controlnet_id}...")
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_id, subfolder="diffusion_sd15", torch_dtype=self.dtype,
        )
        print(f"Loading base model from {self.base_model_id}...")
        self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.base_model_id,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        # DPM++ 2M Karras - better skin than UniPC
        self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self._pipe.scheduler.config,
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )
        # FP32 VAE decode - prevents color banding on skin
        if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
            self._pipe.vae.config.force_upcast = True
        self._apply_device_optimizations()

    def _load_ip_adapter(self) -> None:
        """Load IP-Adapter for identity preservation.

        A failed load is reported and leaves the pipeline in ControlNet-only
        mode rather than raising.

        Raises:
            RuntimeError: If the base pipeline has not been loaded yet.
        """
        if self._pipe is None:
            raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
        try:
            print(f"Loading IP-Adapter ({self.IP_ADAPTER_WEIGHT_NAME})...")
            self._pipe.load_ip_adapter(
                self.IP_ADAPTER_REPO,
                subfolder=self.IP_ADAPTER_SUBFOLDER,
                weight_name=self.IP_ADAPTER_WEIGHT_NAME,
            )
            self._pipe.set_ip_adapter_scale(self.ip_adapter_scale)
            self._ip_adapter_loaded = True
            print(f"IP-Adapter loaded (scale={self.ip_adapter_scale})")
        except Exception as e:
            print(f"WARNING: IP-Adapter load failed: {e}")
            print("Falling back to ControlNet-only mode")
            self._ip_adapter_loaded = False

    def _load_img2img(self) -> None:
        """Build the plain img2img pipeline and place it on the device."""
        from diffusers import (
            StableDiffusionImg2ImgPipeline,
            DPMSolverMultistepScheduler,
        )

        print(f"Loading SD1.5 img2img from {self.base_model_id}...")
        self._pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            self.base_model_id,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self._pipe.scheduler.config
        )
        self._apply_device_optimizations()

    def _apply_device_optimizations(self) -> None:
        """Move the loaded pipeline to ``self.device`` with per-backend tweaks."""
        if self.device.type == "mps":
            self._pipe = self._pipe.to(self.device)
            self._pipe.enable_attention_slicing()
        elif self.device.type == "cuda":
            try:
                self._pipe.enable_model_cpu_offload()
            except Exception:
                # Offload could not be enabled; keep the whole model on the GPU.
                self._pipe = self._pipe.to(self.device)
        else:
            # CPU (or any other backend): sequential CPU offload needs an
            # accelerator to offload *to*, so just place the pipeline directly.
            self._pipe = self._pipe.to(self.device)
        print(f"Pipeline loaded on {self.device} ({self.dtype})")

    @property
    def is_loaded(self) -> bool:
        """True once a diffusion pipeline exists, or always in ``"tps"`` mode."""
        return self._pipe is not None or self.mode == "tps"

    def generate(
        self,
        image: np.ndarray,
        procedure: str = "rhinoplasty",
        intensity: float = 50.0,
        num_inference_steps: int = 30,
        guidance_scale: float = 9.0,
        controlnet_conditioning_scale: float = 0.9,
        strength: float = 0.5,
        seed: Optional[int] = None,
        clinical_flags: Optional["ClinicalFlags"] = None,
        postprocess: bool = True,
        use_gfpgan: bool = False,
    ) -> dict:
        """Run the full prediction for one image.

        The image is resized to 512x512, landmarks are extracted and displaced
        for ``procedure``, a TPS warp is computed, and (unless in ``"tps"``
        mode) a diffusion pass produces the raw output, which is then
        post-processed or mask-composited back onto the input.

        Args:
            image: Input image array (BGR, as read by ``cv2.imread``).
            procedure: Procedure name; also selects the text prompt, with a
                generic prompt for names missing from ``PROCEDURE_PROMPTS``.
            intensity: Forwarded to ``apply_procedure_preset``.
            num_inference_steps: Diffusion steps.
            guidance_scale: Classifier-free guidance scale.
            controlnet_conditioning_scale: ControlNet modes only.
            strength: img2img mode only.
            seed: Seeds a CPU ``torch.Generator`` when given.
            clinical_flags: Overrides the flags passed to ``__init__``.
            postprocess: Use ``full_postprocess`` (diffusion modes only);
                otherwise a plain ``mask_composite`` is applied.
            use_gfpgan: Enables the restoration/upscaling options of
                ``full_postprocess``.

        Returns:
            Dict with the composited ``output``, ``output_raw``, ``output_tps``,
            the resized ``input``, landmark objects, ``conditioning`` image,
            ``mask`` and run metadata.

        Raises:
            RuntimeError: If ``load()`` has not been called.
            ValueError: If no face is detected.
        """
        if not self.is_loaded:
            raise RuntimeError("Pipeline not loaded. Call .load() first.")

        flags = clinical_flags or self.clinical_flags
        image_512 = cv2.resize(image, (512, 512))
        face = extract_landmarks(image_512)
        if face is None:
            raise ValueError("No face detected in image.")

        # face view angle for multi-view awareness
        view_info = estimate_face_view(face)

        manipulated = apply_procedure_preset(
            face, procedure, intensity, image_size=512, clinical_flags=flags,
        )
        landmark_img = render_landmark_image(manipulated, 512, 512)
        mask = generate_surgical_mask(
            face, procedure, 512, 512, clinical_flags=flags,
        )

        generator = None
        if seed is not None:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")

        # TPS warp is always the geometric baseline
        tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)

        if self.mode == "tps":
            raw_output = tps_warped
        elif self.mode in ("controlnet", "controlnet_ip"):
            ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
            raw_output = self._generate_controlnet(
                image_512, landmark_img, prompt, num_inference_steps,
                guidance_scale, controlnet_conditioning_scale, generator,
                ip_adapter_image=ip_image,
            )
        else:
            raw_output = self._generate_img2img(
                tps_warped, mask, prompt, num_inference_steps,
                guidance_scale, strength, generator,
            )

        # postprocess for photorealism
        identity_check = None
        restore_used = "none"
        if postprocess and self.mode != "tps":
            from landmarkdiff.postprocess import full_postprocess

            pp_result = full_postprocess(
                generated=raw_output,
                original=image_512,
                mask=mask,
                restore_mode="codeformer" if use_gfpgan else "none",
                use_realesrgan=use_gfpgan,
                use_laplacian_blend=True,
                sharpen_strength=0.25,
                verify_identity=True,
            )
            composited = pp_result["image"]
            identity_check = pp_result["identity_check"]
            restore_used = pp_result["restore_used"]
        else:
            composited = mask_composite(raw_output, image_512, mask)

        return {
            "output": composited,
            "output_raw": raw_output,
            "output_tps": tps_warped,
            "input": image_512,
            "landmarks_original": face,
            "landmarks_manipulated": manipulated,
            "conditioning": landmark_img,
            "mask": mask,
            "procedure": procedure,
            "intensity": intensity,
            "device": str(self.device),
            "mode": self.mode,
            "view_info": view_info,
            "ip_adapter_active": self._ip_adapter_loaded,
            "identity_check": identity_check,
            "restore_used": restore_used,
        }

    def _generate_controlnet(
        self, image: np.ndarray, conditioning: np.ndarray,
        prompt: str, steps: int, cfg: float, cn_scale: float,
        generator: torch.Generator | None,
        ip_adapter_image: Image.Image | None = None,
    ) -> np.ndarray:
        """Run the ControlNet pipeline and return the first image as an array.

        ``image`` is accepted for signature symmetry but not used: only the
        rendered landmark ``conditioning`` image is fed to the pipeline.
        """
        kwargs = dict(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=numpy_to_pil(conditioning),  # control conditioning only
            num_inference_steps=steps,
            guidance_scale=cfg,
            controlnet_conditioning_scale=cn_scale,
            generator=generator,
        )
        if ip_adapter_image is not None and self._ip_adapter_loaded:
            kwargs["ip_adapter_image"] = ip_adapter_image
        result = self._pipe(**kwargs)
        return pil_to_numpy(result.images[0])

    def _generate_img2img(
        self, image: np.ndarray, mask: np.ndarray,
        prompt: str, steps: int, cfg: float, strength: float,
        generator: torch.Generator | None,
    ) -> np.ndarray:
        """Run the img2img pipeline on ``image`` and return the first result.

        ``mask`` is not used here; masking is applied by the caller when the
        result is composited back onto the input.
        """
        result = self._pipe(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=numpy_to_pil(image),
            num_inference_steps=steps,
            guidance_scale=cfg,
            strength=strength,
            generator=generator,
        )
        return pil_to_numpy(result.images[0])
def estimate_face_view(face: FaceLandmarks) -> dict:
    """Estimate a rough head pose from five landmarks.

    Yaw comes from the imbalance between the nose-to-left and nose-to-right
    side-landmark distances (indices 234 / 454). Pitch comes from the
    imbalance between nose-to-chin (152) and forehead-to-nose (10) distances,
    scaled to +/-45.

    Returns:
        Dict with ``yaw`` and ``pitch`` (rounded to one decimal), a ``view``
        label (``"frontal"``, ``"three_quarter"`` or ``"profile"``),
        ``is_frontal`` and an optional ``warning`` string.
    """
    pts = face.pixel_coords
    nose, side_l, side_r = pts[1], pts[234], pts[454]
    brow, jaw = pts[10], pts[152]

    # Yaw: symmetric distances give 0 degrees
    d_left = np.linalg.norm(nose - side_l)
    d_right = np.linalg.norm(nose - side_r)
    side_span = d_left + d_right
    if side_span < 1.0:
        # Degenerate geometry: avoid dividing by a near-zero span.
        yaw = 0.0
    else:
        imbalance = (d_right - d_left) / side_span
        yaw = float(np.arcsin(np.clip(imbalance, -1, 1)) * 180 / np.pi)

    # Pitch: lower-face vs upper-face length ratio
    d_upper = np.linalg.norm(brow - nose)
    d_lower = np.linalg.norm(nose - jaw)
    if d_upper + d_lower < 1.0:
        pitch = 0.0
    else:
        pitch = float((d_lower - d_upper) / (d_upper + d_lower) * 45)

    # Classify by yaw magnitude; the first bound that holds wins.
    yaw_mag = abs(yaw)
    view = next(
        (label for bound, label in ((15, "frontal"), (45, "three_quarter")) if yaw_mag < bound),
        "profile",
    )

    warning = None
    if yaw_mag > 30:
        warning = "Side-view detected: results may be less accurate"

    return {
        "yaw": round(yaw, 1),
        "pitch": round(pitch, 1),
        "view": view,
        "is_frontal": yaw_mag < 15,
        "warning": warning,
    }
def run_inference(
    image_path: str,
    procedure: str = "rhinoplasty",
    intensity: float = 50.0,
    output_dir: str = "scripts/inference_output",
    seed: int = 42,
    mode: str = "img2img",
    ip_adapter_scale: float = 0.6,
) -> None:
    """Run one prediction from an image file and write the results to disk.

    Creates ``output_dir`` if needed, loads the image, builds and loads a
    ``LandmarkDiffPipeline``, then saves ``input``, ``output``, ``output_raw``,
    ``output_tps``, ``conditioning``, ``mask`` and a side-by-side
    ``comparison`` as PNG files. Exits the process with status 1 if the image
    cannot be read.
    """
    dest = Path(output_dir)
    dest.mkdir(parents=True, exist_ok=True)

    source_image = cv2.imread(image_path)
    if source_image is None:
        print(f"ERROR: Could not load {image_path}")
        sys.exit(1)

    pipeline = LandmarkDiffPipeline(mode=mode, ip_adapter_scale=ip_adapter_scale)
    pipeline.load()

    print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
    result = pipeline.generate(
        source_image, procedure=procedure, intensity=intensity, seed=seed,
    )

    # Each of these result entries is saved as-is under "<key>.png".
    for key in ("input", "output", "output_raw", "output_tps", "conditioning"):
        cv2.imwrite(str(dest / f"{key}.png"), result[key])

    # The mask is scaled by 255 and cast to 8-bit before saving.
    mask_8bit = (result["mask"] * 255).astype(np.uint8)
    cv2.imwrite(str(dest / "mask.png"), mask_8bit)

    side_by_side = np.hstack([result["input"], result["output_tps"], result["output"]])
    cv2.imwrite(str(dest / "comparison.png"), side_by_side)

    view = result.get("view_info", {})
    warning = view.get("warning")
    if warning:
        print(f"WARNING: {warning}")
    print(f"Face view: {view.get('view', 'unknown')} (yaw={view.get('yaw', 0)})")
    print(f"Results saved to {dest}/")
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run_inference.
    import argparse

    cli = argparse.ArgumentParser(description="LandmarkDiff inference")
    cli.add_argument("image", help="Path to face image")
    cli.add_argument("--procedure", default="rhinoplasty")
    cli.add_argument("--intensity", type=float, default=50.0)
    cli.add_argument("--output", default="scripts/inference_output")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument(
        "--mode",
        default="img2img",
        choices=["img2img", "controlnet", "controlnet_ip", "tps"],
    )
    cli.add_argument("--ip-adapter-scale", type=float, default=0.6)
    opts = cli.parse_args()

    run_inference(
        image_path=opts.image,
        procedure=opts.procedure,
        intensity=opts.intensity,
        output_dir=opts.output,
        seed=opts.seed,
        mode=opts.mode,
        ip_adapter_scale=opts.ip_adapter_scale,
    )