Spaces:
Running
Running
| """Synthetic training pair generator. | |
| Creates (input, conditioning, mask, target) tuples for ControlNet fine-tuning. | |
| Pipeline: FFHQ image -> extract landmarks -> random FFD manipulation -> | |
| generate conditioning + mask -> apply clinical augmentation to input. | |
| Augmentations are applied to INPUT only, never to target (ground truth). | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import Iterator | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| from landmarkdiff.conditioning import generate_conditioning | |
| from landmarkdiff.landmarks import extract_landmarks, render_landmark_image | |
| from landmarkdiff.manipulation import ( | |
| PROCEDURE_LANDMARKS, | |
| apply_procedure_preset, | |
| ) | |
| from landmarkdiff.masking import generate_surgical_mask | |
| from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation | |
| from landmarkdiff.synthetic.tps_warp import warp_image_tps | |
@dataclass
class TrainingPair:
    """A single training sample for ControlNet fine-tuning.

    Declared as a dataclass so that instances can be built with keyword
    arguments, one per field below.
    """

    # Image sizes quoted below are for the default 512-pixel output.
    input_image: np.ndarray  # augmented input (512x512 BGR)
    target_image: np.ndarray  # clean target (512x512 BGR) — TPS-warped original
    conditioning: np.ndarray  # landmark rendering (512x512 BGR)
    canny: np.ndarray  # canny edge map (512x512 grayscale)
    mask: np.ndarray  # feathered surgical mask (512x512 float32)
    procedure: str  # name of the procedure preset applied to the landmarks
    intensity: float  # manipulation intensity passed to the preset
# Keys of PROCEDURE_LANDMARKS as a list, so a procedure can be sampled at random.
PROCEDURES = list(PROCEDURE_LANDMARKS)
def generate_pair(
    image: np.ndarray,
    procedure: str | None = None,
    intensity: float | None = None,
    target_size: int = 512,
    rng: np.random.Generator | None = None,
) -> TrainingPair | None:
    """Generate a single training pair from a face image.

    The image is resized to a square, its landmarks are extracted and
    manipulated with a procedure preset, and the original pixels are
    TPS-warped onto the manipulated landmarks to form the target.

    Args:
        image: BGR input image (any size).
        procedure: Procedure type (random if None).
        intensity: Manipulation intensity 0-100 (random 30-90 if None).
        target_size: Output resolution (width and height, in pixels).
        rng: Random number generator; a fresh unseeded one is created if None.

    Returns:
        TrainingPair or None if face detection fails.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Resize to target
    resized = cv2.resize(image, (target_size, target_size))

    # Extract landmarks; without a detected face there is nothing to build.
    face = extract_landmarks(resized)
    if face is None:
        return None

    # Random procedure and intensity if not specified.  The draw order
    # (procedure, then intensity) is kept so seeded runs stay reproducible.
    if procedure is None:
        # Generator.choice returns a NumPy string scalar; store a plain str
        # so the field matches its annotation.
        procedure = str(rng.choice(PROCEDURES))
    if intensity is None:
        intensity = float(rng.uniform(30, 90))

    # Manipulate landmarks
    manipulated = apply_procedure_preset(face, procedure, intensity, target_size)

    # Generate conditioning from manipulated landmarks.  Only the second
    # value returned by generate_conditioning (the canny map) is used here.
    landmark_img = render_landmark_image(manipulated, target_size, target_size)
    _, canny, _ = generate_conditioning(manipulated, target_size, target_size)

    # Generate mask (from the original, un-manipulated landmarks)
    mask = generate_surgical_mask(face, procedure, target_size, target_size)

    # Generate target: TPS warp the original image to match manipulated landmarks
    target = warp_image_tps(resized, face.pixel_coords, manipulated.pixel_coords)

    # Apply clinical augmentation to INPUT only (never target)
    augmented_input = apply_clinical_augmentation(resized, rng=rng)

    return TrainingPair(
        input_image=augmented_input,
        target_image=target,
        conditioning=landmark_img,
        canny=canny,
        mask=mask,
        procedure=procedure,
        intensity=intensity,
    )
def _estimate_source_quality(image: np.ndarray, target_size: int) -> float:
    """Score *image* with the face verifier, or return 100.0 if it cannot be scored."""
    try:
        # Imported lazily so the verifier is only needed when a score is requested.
        from landmarkdiff.face_verifier import analyze_distortions

        resized = cv2.resize(image, (target_size, target_size))
        return analyze_distortions(resized).quality_score
    except Exception:
        # Deliberate best-effort: can't check — allow the image through.
        return 100.0


def generate_pairs_from_directory(
    image_dir: str | Path,
    num_pairs: int = 1000,
    target_size: int = 512,
    seed: int = 42,
    quality_check: bool = True,
    min_quality: float = 45.0,
) -> Iterator[TrainingPair]:
    """Generate training pairs from a directory of face images.

    Image files are visited in sorted order and cycled through repeatedly
    until ``num_pairs`` pairs have been produced.  Generation stops early,
    with a warning, once more consecutive attempts have failed than there
    are image files, so fewer than ``num_pairs`` pairs may be yielded.

    Args:
        image_dir: Directory containing face images
            (.jpg, .jpeg, .png or .webp files).
        num_pairs: Total number of pairs to generate.
        target_size: Output resolution.
        seed: Random seed.
        quality_check: Run face verifier quality check on source images.
        min_quality: Minimum quality score to use image (0-100).

    Yields:
        TrainingPair instances.

    Raises:
        FileNotFoundError: If the directory contains no image files.  As
            this is a generator, the error surfaces on first iteration.
    """
    rng = np.random.default_rng(seed)
    image_dir = Path(image_dir)
    extensions = {".jpg", ".jpeg", ".png", ".webp"}
    image_files = sorted(
        f for f in image_dir.iterdir()
        # Skip sub-directories that merely have an image-like name.
        if f.is_file() and f.suffix.lower() in extensions
    )
    if not image_files:
        raise FileNotFoundError(f"No images found in {image_dir}")

    num_files = len(image_files)
    quality_cache: dict[str, float] = {}  # image path -> quality score
    # Counts every skip: an image revisited in a later cycle is counted again.
    quality_rejects = 0
    generated = 0
    consecutive_failures = 0
    idx = 0

    while generated < num_pairs:
        # Cycle through images
        img_path = image_files[idx % num_files]
        idx += 1
        cache_key = str(img_path)

        pair = None
        score = quality_cache.get(cache_key) if quality_check else None
        # An image already known to fail the gate is skipped without
        # reading it from disk again.
        rejected = score is not None and score < min_quality

        if not rejected:
            image = cv2.imread(cache_key)
            if image is not None:
                # Quality gate: reject low-quality source images before
                # pair generation.
                if quality_check:
                    if score is None:
                        score = _estimate_source_quality(image, target_size)
                        quality_cache[cache_key] = score
                    rejected = score < min_quality
                if not rejected:
                    pair = generate_pair(image, target_size=target_size, rng=rng)

        if pair is not None:
            yield pair
            generated += 1
            consecutive_failures = 0
            continue

        if rejected:
            quality_rejects += 1
            if quality_rejects % 100 == 0:
                print(f"  Quality filter: {quality_rejects} images rejected so far")

        # Unreadable file, quality reject, or no face detected: all count
        # towards the early-stop limit and share the same warning.
        consecutive_failures += 1
        if consecutive_failures > num_files:
            print(f"Warning: {consecutive_failures} consecutive failures, stopping early")
            break

    if quality_rejects > 0:
        print(f"Quality filter: rejected {quality_rejects} low-quality source images")
def save_pair(pair: TrainingPair, output_dir: str | Path, index: int) -> None:
    """Save a training pair to disk as five PNG files.

    Files are named ``{index:06d}_{kind}.png`` where ``kind`` is one of
    ``input``, ``target``, ``conditioning``, ``canny`` and ``mask``.

    Args:
        pair: The training pair to write.
        output_dir: Destination directory; created (with parents) if missing.
        index: Sample number used as the zero-padded file-name prefix.

    Raises:
        OSError: If any of the images could not be written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    prefix = f"{index:06d}"

    # The mask is scaled by 255 for storage; clip to [0, 1] first so
    # out-of-range values cannot wrap around in the uint8 cast.
    mask_u8 = (np.clip(pair.mask, 0.0, 1.0) * 255).astype(np.uint8)

    outputs = {
        "input": pair.input_image,
        "target": pair.target_image,
        "conditioning": pair.conditioning,
        "canny": pair.canny,
        "mask": mask_u8,
    }
    for kind, array in outputs.items():
        path = output_dir / f"{prefix}_{kind}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(path), array):
            raise OSError(f"Failed to write {path}")