"""Post-processing: CodeFormer/GFPGAN face restore, Real-ESRGAN bg, Laplacian blend, sharpening, histogram matching, ArcFace identity gate. """ from __future__ import annotations import cv2 import numpy as np def laplacian_pyramid_blend( source: np.ndarray, target: np.ndarray, mask: np.ndarray, levels: int = 6, ) -> np.ndarray: """Laplacian pyramid blend - kills the 'pasted on' look from alpha blending.""" # Ensure same size h, w = target.shape[:2] source = cv2.resize(source, (w, h)) if source.shape[:2] != (h, w) else source # Normalize mask mask_f = mask.astype(np.float32) if mask_f.max() > 1.0: mask_f = mask_f / 255.0 if mask_f.ndim == 2: mask_3ch = np.stack([mask_f] * 3, axis=-1) else: mask_3ch = mask_f # Make dimensions divisible by 2^levels factor = 2 ** levels new_h = (h + factor - 1) // factor * factor new_w = (w + factor - 1) // factor * factor if new_h != h or new_w != w: source = cv2.resize(source, (new_w, new_h)) target = cv2.resize(target, (new_w, new_h)) mask_3ch = cv2.resize(mask_3ch, (new_w, new_h)) src_f = source.astype(np.float32) tgt_f = target.astype(np.float32) # Build Gaussian pyramids for the mask mask_pyr = [mask_3ch] for _ in range(levels): mask_pyr.append(cv2.pyrDown(mask_pyr[-1])) # Build Laplacian pyramids for source and target src_lap = _build_laplacian_pyramid(src_f, levels) tgt_lap = _build_laplacian_pyramid(tgt_f, levels) # Blend each level using the mask at that resolution blended_lap = [] for i in range(levels + 1): sl = src_lap[i] tl = tgt_lap[i] ml = mask_pyr[i] # Resize mask to match level shape if needed if ml.shape[:2] != sl.shape[:2]: ml = cv2.resize(ml, (sl.shape[1], sl.shape[0])) blended = sl * ml + tl * (1.0 - ml) blended_lap.append(blended) # Reconstruct from blended Laplacian result = _reconstruct_from_laplacian(blended_lap) # Crop back to original size result = result[:h, :w] return np.clip(result, 0, 255).astype(np.uint8) def _build_laplacian_pyramid( image: np.ndarray, levels: int, ) -> list[np.ndarray]: """Build Laplacian pyramid from an image.""" gaussian = [image.copy()] for _ in range(levels): gaussian.append(cv2.pyrDown(gaussian[-1])) laplacian = [] for i in range(levels): upsampled = cv2.pyrUp(gaussian[i + 1]) # Match sizes (pyrUp can add a pixel) gh, gw = gaussian[i].shape[:2] upsampled = upsampled[:gh, :gw] laplacian.append(gaussian[i] - upsampled) laplacian.append(gaussian[-1]) # coarsest level return laplacian def _reconstruct_from_laplacian(pyramid: list[np.ndarray]) -> np.ndarray: """Reconstruct image from Laplacian pyramid.""" image = pyramid[-1].copy() for i in range(len(pyramid) - 2, -1, -1): image = cv2.pyrUp(image) lh, lw = pyramid[i].shape[:2] image = image[:lh, :lw] image = image + pyramid[i] return image def frequency_aware_sharpen( image: np.ndarray, strength: float = 0.3, radius: int = 3, ) -> np.ndarray: """Unsharp mask on LAB luminance only - sharpens skin texture without color fringe.""" lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32) l_channel = lab[:, :, 0] # Unsharp mask on luminance only ksize = radius * 2 + 1 blurred = cv2.GaussianBlur(l_channel, (ksize, ksize), 0) sharpened = l_channel + strength * (l_channel - blurred) lab[:, :, 0] = np.clip(sharpened, 0, 255) return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR) def restore_face_gfpgan( image: np.ndarray, upscale: int = 1, ) -> np.ndarray: """GFPGAN face restore. Returns original if not installed.""" try: from gfpgan import GFPGANer except ImportError: return image try: restorer = GFPGANer( model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", upscale=upscale, arch="clean", channel_multiplier=2, bg_upsampler=None, ) _, _, restored = restorer.enhance( image, has_aligned=False, only_center_face=True, paste_back=True, ) if restored is not None: return restored except Exception: pass return image def restore_face_codeformer( image: np.ndarray, fidelity: float = 0.7, upscale: int = 1, ) -> np.ndarray: """CodeFormer face restore. fidelity: 0=quality, 1=identity. Returns original if not installed.""" try: from codeformer.basicsr.utils import img2tensor, tensor2img from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper from codeformer.basicsr.utils.download_util import load_file_from_url import torch from torchvision.transforms.functional import normalize as tv_normalize except ImportError: return image try: from codeformer.inference_codeformer import set_realesrgan as _unused # noqa: F401 from codeformer.basicsr.archs.codeformer_arch import CodeFormer as CodeFormerArch device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = CodeFormerArch( dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=["32", "64", "128", "256"], ).to(device) ckpt_path = load_file_from_url( url="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth", model_dir="weights/CodeFormer", progress=True, ) checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) model.load_state_dict(checkpoint["params_ema"]) model.eval() face_helper = FaceRestoreHelper( upscale, face_size=512, crop_ratio=(1, 1), det_model="retinaface_resnet50", save_ext="png", device=device, ) face_helper.read_image(image) face_helper.get_face_landmarks_5(only_center_face=True) face_helper.align_warp_face() for cropped_face in face_helper.cropped_faces: face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) tv_normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) face_t = face_t.unsqueeze(0).to(device) with torch.no_grad(): output = model(face_t, w=fidelity, adain=True)[0] restored = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) restored = restored.astype(np.uint8) face_helper.add_restored_face(restored) face_helper.get_inverse_affine(None) restored_img = face_helper.paste_faces_to_image() if restored_img is not None: return restored_img except Exception: pass return image def enhance_background_realesrgan( image: np.ndarray, mask: np.ndarray, outscale: int = 2, ) -> np.ndarray: """Real-ESRGAN on background only (outside mask). Returns original if not installed.""" try: from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet import torch except ImportError: return image try: model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) upsampler = RealESRGANer( scale=4, model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", model=model, tile=400, tile_pad=10, pre_pad=0, half=torch.cuda.is_available(), ) enhanced, _ = upsampler.enhance(image, outscale=outscale) # Downscale back to original size h, w = image.shape[:2] enhanced = cv2.resize(enhanced, (w, h), interpolation=cv2.INTER_LANCZOS4) # Only apply enhancement to background (outside mask) mask_f = mask.astype(np.float32) if mask_f.max() > 1.0: mask_f /= 255.0 if mask_f.ndim == 2: mask_3ch = np.stack([mask_f] * 3, axis=-1) else: mask_3ch = mask_f # Keep face region from original, use enhanced for background result = ( image.astype(np.float32) * mask_3ch + enhanced.astype(np.float32) * (1.0 - mask_3ch) ).astype(np.uint8) return result except Exception: pass return image def verify_identity_arcface( original: np.ndarray, result: np.ndarray, threshold: float = 0.6, ) -> dict: """ArcFace cosine similarity check. Flags if output drifted from input identity.""" try: from insightface.app import FaceAnalysis except ImportError: return { "similarity": -1.0, "passed": True, "message": "InsightFace not installed - identity check skipped", } try: app = FaceAnalysis( name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) app.prepare(ctx_id=0 if _has_cuda() else -1, det_size=(320, 320)) orig_faces = app.get(original) result_faces = app.get(result) if not orig_faces or not result_faces: return { "similarity": -1.0, "passed": True, "message": "Could not detect face in one/both images - check skipped", } orig_emb = orig_faces[0].embedding result_emb = result_faces[0].embedding sim = float(np.dot(orig_emb, result_emb) / ( np.linalg.norm(orig_emb) * np.linalg.norm(result_emb) + 1e-8 )) sim = float(np.clip(sim, 0, 1)) passed = sim >= threshold if passed: msg = f"Identity preserved (similarity={sim:.3f})" else: msg = f"WARNING: Identity drift detected (similarity={sim:.3f} < {threshold})" return {"similarity": sim, "passed": passed, "message": msg} except Exception as e: return { "similarity": -1.0, "passed": True, "message": f"Identity check failed: {e}", } def _has_cuda() -> bool: try: import torch return torch.cuda.is_available() except ImportError: return False def histogram_match_skin( source: np.ndarray, reference: np.ndarray, mask: np.ndarray, ) -> np.ndarray: """CDF-based histogram matching in LAB space. Better than mean/std for skin.""" mask_bool = mask > 0.3 if mask.dtype == np.float32 else mask > 76 if not np.any(mask_bool): return source result = source.copy() src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32) ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32) for ch in range(3): src_vals = src_lab[:, :, ch][mask_bool] ref_vals = ref_lab[:, :, ch][mask_bool] if len(src_vals) == 0 or len(ref_vals) == 0: continue # CDF matching src_sorted = np.sort(src_vals) ref_sorted = np.sort(ref_vals) # Interpolate reference CDF to match source length src_cdf = np.linspace(0, 1, len(src_sorted)) ref_cdf = np.linspace(0, 1, len(ref_sorted)) # Map source values through reference distribution mapping = np.interp(src_cdf, ref_cdf, ref_sorted) # Create lookup from source intensity to matched intensity src_flat = src_lab[:, :, ch].ravel() matched = np.interp(src_flat, src_sorted, mapping) matched_2d = matched.reshape(src_lab.shape[:2]) # Apply only in mask region src_lab[:, :, ch] = np.where(mask_bool, matched_2d, src_lab[:, :, ch]) result_lab = np.clip(src_lab, 0, 255).astype(np.uint8) return cv2.cvtColor(result_lab, cv2.COLOR_LAB2BGR) def full_postprocess( generated: np.ndarray, original: np.ndarray, mask: np.ndarray, restore_mode: str = "codeformer", codeformer_fidelity: float = 0.7, use_realesrgan: bool = True, use_laplacian_blend: bool = True, sharpen_strength: float = 0.25, verify_identity: bool = True, identity_threshold: float = 0.6, ) -> dict: """Full pipeline: restore -> bg enhance -> histogram match -> sharpen -> blend -> identity check.""" result = generated.copy() restore_used = "none" # Step 1: Neural face restoration (CodeFormer > GFPGAN > skip) if restore_mode == "codeformer": restored = restore_face_codeformer(result, fidelity=codeformer_fidelity) if restored is not result: result = restored restore_used = "codeformer" else: # CodeFormer unavailable, fall back to GFPGAN result = restore_face_gfpgan(result) restore_used = "gfpgan" if result is not generated else "none" elif restore_mode == "gfpgan": restored = restore_face_gfpgan(result) if restored is not result: result = restored restore_used = "gfpgan" # Step 2: Neural background enhancement if use_realesrgan: result = enhance_background_realesrgan(result, mask) # Step 3: Skin tone histogram matching (classical) result = histogram_match_skin(result, original, mask) # Step 4: Sharpen texture (classical) if sharpen_strength > 0: result = frequency_aware_sharpen(result, strength=sharpen_strength) # Step 5: Blend into original (classical) if use_laplacian_blend: composited = laplacian_pyramid_blend(result, original, mask) else: mask_f = mask.astype(np.float32) if mask_f.max() > 1.0: mask_f /= 255.0 if mask_f.ndim == 2: mask_3ch = np.stack([mask_f] * 3, axis=-1) else: mask_3ch = mask_f composited = ( result.astype(np.float32) * mask_3ch + original.astype(np.float32) * (1.0 - mask_3ch) ).astype(np.uint8) # Step 6: Neural identity verification identity_check = {"similarity": -1.0, "passed": True, "message": "skipped"} if verify_identity: identity_check = verify_identity_arcface( original, composited, threshold=identity_threshold, ) return { "image": composited, "identity_check": identity_check, "restore_used": restore_used, }