"""Synthetic training pair generator. Creates (input, conditioning, mask, target) tuples for ControlNet fine-tuning. Pipeline: FFHQ image -> extract landmarks -> random FFD manipulation -> generate conditioning + mask -> apply clinical augmentation to input. Augmentations are applied to INPUT only, never to target (ground truth). """ from __future__ import annotations from collections.abc import Iterator from dataclasses import dataclass from pathlib import Path import cv2 import numpy as np from landmarkdiff.conditioning import generate_conditioning from landmarkdiff.landmarks import extract_landmarks, render_landmark_image from landmarkdiff.manipulation import ( PROCEDURE_LANDMARKS, apply_procedure_preset, ) from landmarkdiff.masking import generate_surgical_mask from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation from landmarkdiff.synthetic.tps_warp import warp_image_tps @dataclass(frozen=True) class TrainingPair: """A single training sample for ControlNet fine-tuning.""" input_image: np.ndarray # augmented input (512x512 BGR) target_image: np.ndarray # clean target (512x512 BGR) — TPS-warped original conditioning: np.ndarray # landmark rendering (512x512 BGR) canny: np.ndarray # canny edge map (512x512 grayscale) mask: np.ndarray # feathered surgical mask (512x512 float32) procedure: str intensity: float PROCEDURES = list(PROCEDURE_LANDMARKS.keys()) def generate_pair( image: np.ndarray, procedure: str | None = None, intensity: float | None = None, target_size: int = 512, rng: np.random.Generator | None = None, ) -> TrainingPair | None: """Generate a single training pair from a face image. Args: image: BGR input image (any size). procedure: Procedure type (random if None). intensity: Manipulation intensity 0-100 (random 30-90 if None). target_size: Output resolution. rng: Random number generator. Returns: TrainingPair or None if face detection fails. """ rng = rng or np.random.default_rng() # Resize to target resized = cv2.resize(image, (target_size, target_size)) # Extract landmarks face = extract_landmarks(resized) if face is None: return None # Random procedure and intensity if not specified if procedure is None: procedure = rng.choice(PROCEDURES) if intensity is None: intensity = float(rng.uniform(30, 90)) # Manipulate landmarks manipulated = apply_procedure_preset(face, procedure, intensity, target_size) # Generate conditioning from manipulated landmarks landmark_img = render_landmark_image(manipulated, target_size, target_size) _, canny, _ = generate_conditioning(manipulated, target_size, target_size) # Generate mask mask = generate_surgical_mask(face, procedure, target_size, target_size) # Generate target: TPS warp the original image to match manipulated landmarks src_px = face.pixel_coords dst_px = manipulated.pixel_coords target = warp_image_tps(resized, src_px, dst_px) # Apply clinical augmentation to INPUT only (never target) augmented_input = apply_clinical_augmentation(resized, rng=rng) return TrainingPair( input_image=augmented_input, target_image=target, conditioning=landmark_img, canny=canny, mask=mask, procedure=procedure, intensity=intensity, ) def generate_pairs_from_directory( image_dir: str | Path, num_pairs: int = 1000, target_size: int = 512, seed: int = 42, quality_check: bool = True, min_quality: float = 45.0, ) -> Iterator[TrainingPair]: """Generate training pairs from a directory of face images. Args: image_dir: Directory containing face images. num_pairs: Total number of pairs to generate. target_size: Output resolution. seed: Random seed. quality_check: Run face verifier quality check on source images. min_quality: Minimum quality score to use image (0-100). Yields: TrainingPair instances. """ rng = np.random.default_rng(seed) image_dir = Path(image_dir) extensions = {".jpg", ".jpeg", ".png", ".webp"} image_files = sorted( f for f in image_dir.iterdir() if f.suffix.lower() in extensions ) if not image_files: raise FileNotFoundError(f"No images found in {image_dir}") # Optional quality pre-filter _quality_cache: dict[str, float] = {} quality_rejects = 0 generated = 0 consecutive_failures = 0 idx = 0 while generated < num_pairs: # Cycle through images img_path = image_files[idx % len(image_files)] idx += 1 image = cv2.imread(str(img_path)) if image is None: consecutive_failures += 1 if consecutive_failures > len(image_files): print(f"Warning: {consecutive_failures} consecutive failures, stopping early") break continue # Quality gate: reject low-quality source images before pair generation if quality_check: cache_key = str(img_path) if cache_key not in _quality_cache: try: from landmarkdiff.face_verifier import analyze_distortions resized = cv2.resize(image, (target_size, target_size)) report = analyze_distortions(resized) _quality_cache[cache_key] = report.quality_score except Exception: _quality_cache[cache_key] = 100.0 # Can't check — allow through if _quality_cache[cache_key] < min_quality: quality_rejects += 1 if quality_rejects % 100 == 0: print(f" Quality filter: {quality_rejects} images rejected so far") consecutive_failures += 1 if consecutive_failures > len(image_files): break continue pair = generate_pair(image, target_size=target_size, rng=rng) if pair is not None: yield pair generated += 1 consecutive_failures = 0 else: consecutive_failures += 1 if consecutive_failures > len(image_files): print(f"Warning: {consecutive_failures} consecutive failures, stopping early") break if quality_rejects > 0: print(f"Quality filter: rejected {quality_rejects} low-quality source images") def save_pair(pair: TrainingPair, output_dir: Path, index: int) -> None: """Save a training pair to disk.""" output_dir.mkdir(parents=True, exist_ok=True) prefix = f"{index:06d}" cv2.imwrite(str(output_dir / f"{prefix}_input.png"), pair.input_image) cv2.imwrite(str(output_dir / f"{prefix}_target.png"), pair.target_image) cv2.imwrite(str(output_dir / f"{prefix}_conditioning.png"), pair.conditioning) cv2.imwrite(str(output_dir / f"{prefix}_canny.png"), pair.canny) cv2.imwrite(str(output_dir / f"{prefix}_mask.png"), (pair.mask * 255).astype(np.uint8))