dreamlessx commited on
Commit
d2ca600
·
verified ·
1 Parent(s): b229662

Upload landmarkdiff/synthetic/pair_generator.py with huggingface_hub

Browse files
landmarkdiff/synthetic/pair_generator.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Synthetic pair generator for ControlNet fine-tuning.
2
+
3
+ FFHQ -> landmarks -> random FFD -> conditioning + mask -> augment input.
4
+ Augmentations on INPUT only, never target.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Iterator
12
+
13
+ import cv2
14
+ import numpy as np
15
+
16
+ from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
17
+ from landmarkdiff.conditioning import generate_conditioning
18
+ from landmarkdiff.manipulation import (
19
+ PROCEDURE_LANDMARKS,
20
+ apply_procedure_preset,
21
+ )
22
+ from landmarkdiff.masking import generate_surgical_mask
23
+ from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
24
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps, generate_random_warp
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class TrainingPair:
29
+ """A single training sample for ControlNet fine-tuning."""
30
+
31
+ input_image: np.ndarray # augmented input (512x512 BGR)
32
+ target_image: np.ndarray # clean target (512x512 BGR) - TPS-warped original
33
+ conditioning: np.ndarray # landmark rendering (512x512 BGR)
34
+ canny: np.ndarray # canny edge map (512x512 grayscale)
35
+ mask: np.ndarray # feathered surgical mask (512x512 float32)
36
+ procedure: str
37
+ intensity: float
38
+
39
+
40
+ PROCEDURES = ["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"]
41
+
42
+
43
+ def generate_pair(
44
+ image: np.ndarray,
45
+ procedure: str | None = None,
46
+ intensity: float | None = None,
47
+ target_size: int = 512,
48
+ rng: np.random.Generator | None = None,
49
+ ) -> TrainingPair | None:
50
+ """Generate a single training pair from a face image."""
51
+ rng = rng or np.random.default_rng()
52
+
53
+ # Resize to target
54
+ resized = cv2.resize(image, (target_size, target_size))
55
+
56
+ # Extract landmarks
57
+ face = extract_landmarks(resized)
58
+ if face is None:
59
+ return None
60
+
61
+ # Random procedure and intensity if not specified
62
+ if procedure is None:
63
+ procedure = rng.choice(PROCEDURES)
64
+ if intensity is None:
65
+ intensity = float(rng.uniform(30, 90))
66
+
67
+ # Manipulate landmarks
68
+ manipulated = apply_procedure_preset(face, procedure, intensity, target_size)
69
+
70
+ # Generate conditioning from manipulated landmarks
71
+ landmark_img = render_landmark_image(manipulated, target_size, target_size)
72
+ _, canny, _ = generate_conditioning(manipulated, target_size, target_size)
73
+
74
+ # Generate mask
75
+ mask = generate_surgical_mask(face, procedure, target_size, target_size)
76
+
77
+ # Generate target: TPS warp the original image to match manipulated landmarks
78
+ src_px = face.pixel_coords
79
+ dst_px = manipulated.pixel_coords
80
+ target = warp_image_tps(resized, src_px, dst_px)
81
+
82
+ # Apply clinical augmentation to INPUT only (never target)
83
+ augmented_input = apply_clinical_augmentation(resized, rng=rng)
84
+
85
+ return TrainingPair(
86
+ input_image=augmented_input,
87
+ target_image=target,
88
+ conditioning=landmark_img,
89
+ canny=canny,
90
+ mask=mask,
91
+ procedure=procedure,
92
+ intensity=intensity,
93
+ )
94
+
95
+
96
+ def generate_pairs_from_directory(
97
+ image_dir: str | Path,
98
+ num_pairs: int = 1000,
99
+ target_size: int = 512,
100
+ seed: int = 42,
101
+ ) -> Iterator[TrainingPair]:
102
+ """Generate training pairs from a directory of face images."""
103
+ rng = np.random.default_rng(seed)
104
+ image_dir = Path(image_dir)
105
+
106
+ extensions = {".jpg", ".jpeg", ".png", ".webp"}
107
+ image_files = sorted(
108
+ f for f in image_dir.iterdir()
109
+ if f.suffix.lower() in extensions
110
+ )
111
+
112
+ if not image_files:
113
+ raise FileNotFoundError(f"No images found in {image_dir}")
114
+
115
+ generated = 0
116
+ consecutive_failures = 0
117
+ idx = 0
118
+ while generated < num_pairs:
119
+ # Cycle through images
120
+ img_path = image_files[idx % len(image_files)]
121
+ idx += 1
122
+ image = cv2.imread(str(img_path))
123
+ if image is None:
124
+ consecutive_failures += 1
125
+ if consecutive_failures > len(image_files):
126
+ print(f"Warning: {consecutive_failures} consecutive failures, stopping early")
127
+ break
128
+ continue
129
+
130
+ pair = generate_pair(image, target_size=target_size, rng=rng)
131
+ if pair is not None:
132
+ yield pair
133
+ generated += 1
134
+ consecutive_failures = 0
135
+ else:
136
+ consecutive_failures += 1
137
+ if consecutive_failures > len(image_files):
138
+ print(f"Warning: {consecutive_failures} consecutive failures, stopping early")
139
+ break
140
+
141
+
142
+ def save_pair(pair: TrainingPair, output_dir: Path, index: int) -> None:
143
+ """Save a training pair to disk."""
144
+ output_dir.mkdir(parents=True, exist_ok=True)
145
+ prefix = f"{index:06d}"
146
+
147
+ cv2.imwrite(str(output_dir / f"{prefix}_input.png"), pair.input_image)
148
+ cv2.imwrite(str(output_dir / f"{prefix}_target.png"), pair.target_image)
149
+ cv2.imwrite(str(output_dir / f"{prefix}_conditioning.png"), pair.conditioning)
150
+ cv2.imwrite(str(output_dir / f"{prefix}_canny.png"), pair.canny)
151
+ cv2.imwrite(str(output_dir / f"{prefix}_mask.png"), (pair.mask * 255).astype(np.uint8))