dreamlessx committed on
Commit
a1b1648
·
verified ·
1 Parent(s): 0d8a73f

Upload landmarkdiff/ensemble.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/ensemble.py +311 -0
landmarkdiff/ensemble.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ensemble inference for improved output quality.
2
+
3
+ Generates multiple outputs with different random seeds and combines them
4
+ to reduce per-sample variance. Supports multiple aggregation strategies:
5
+ - Pixel-space averaging (fast, slight blur)
6
+ - SSIM-weighted averaging and pixel-wise median (feature-space averaging via VAE encode is not implemented in this module)
7
+ - Best-of-N selection (picks output with highest identity similarity)
8
+
9
+ Usage:
10
+ from landmarkdiff.ensemble import EnsembleInference
11
+
12
+ ensemble = EnsembleInference(
13
+ mode="controlnet",
14
+ controlnet_checkpoint="checkpoints/final/controlnet_ema",
15
+ n_samples=5,
16
+ strategy="best_of_n",
17
+ )
18
+ ensemble.load()
19
+ result = ensemble.generate(image, procedure="rhinoplasty", intensity=65)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from typing import Optional
25
+
26
+ import cv2
27
+ import numpy as np
28
+
29
+
30
class EnsembleInference:
    """Multi-sample ensemble inference for LandmarkDiff.

    Runs the wrapped pipeline ``n_samples`` times with consecutive seeds and
    combines the individual outputs with the configured aggregation strategy.
    """

    # Strategy names accepted by generate(); checked before any sample is run.
    _STRATEGIES = ("pixel_average", "weighted_average", "best_of_n", "median")

    def __init__(
        self,
        mode: str = "controlnet",
        controlnet_checkpoint: str | None = None,
        displacement_model_path: str | None = None,
        n_samples: int = 5,
        strategy: str = "best_of_n",
        base_seed: int = 42,
        **pipeline_kwargs,
    ):
        """Store the ensemble configuration; no model is loaded here.

        Args:
            mode: Pipeline mode (controlnet, img2img, tps).
            controlnet_checkpoint: Path to fine-tuned ControlNet.
            displacement_model_path: Path to displacement model.
            n_samples: Number of ensemble members (must be >= 1 by the time
                generate() is called).
            strategy: Aggregation strategy:
                - "pixel_average": Average in pixel space.
                - "weighted_average": Weighted by SSIM against the input.
                - "best_of_n": Select best by identity similarity.
                - "median": Pixel-wise median (robust to outliers).
            base_seed: Base random seed (sample ``i`` uses ``base_seed + i``).
            **pipeline_kwargs: Additional kwargs for LandmarkDiffPipeline.
        """
        self.mode = mode
        self.controlnet_checkpoint = controlnet_checkpoint
        self.displacement_model_path = displacement_model_path
        self.n_samples = n_samples
        self.strategy = strategy
        self.base_seed = base_seed
        self.pipeline_kwargs = pipeline_kwargs
        self._pipeline = None

    def load(self) -> None:
        """Construct the LandmarkDiffPipeline and load it."""
        # Imported lazily so that importing this module stays cheap.
        from landmarkdiff.inference import LandmarkDiffPipeline

        self._pipeline = LandmarkDiffPipeline(
            mode=self.mode,
            controlnet_checkpoint=self.controlnet_checkpoint,
            displacement_model_path=self.displacement_model_path,
            **self.pipeline_kwargs,
        )
        self._pipeline.load()

    @property
    def is_loaded(self) -> bool:
        """Whether a pipeline has been created and reports itself loaded."""
        return self._pipeline is not None and self._pipeline.is_loaded

    def generate(
        self,
        image: np.ndarray,
        procedure: str = "rhinoplasty",
        intensity: float = 50.0,
        num_inference_steps: int = 30,
        guidance_scale: float = 9.0,
        controlnet_conditioning_scale: float = 0.9,
        strength: float = 0.5,
        seed: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Generate ``n_samples`` outputs and aggregate them.

        Args:
            image: Input image; also used as the reference for the
                "weighted_average" and "best_of_n" strategies.
            procedure: Forwarded to the pipeline.
            intensity: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            controlnet_conditioning_scale: Forwarded to the pipeline.
            strength: Forwarded to the pipeline.
            seed: Overrides ``base_seed`` for this call when not None.
            **kwargs: Extra keyword arguments forwarded to the pipeline.

        Returns:
            A copy of one per-sample pipeline result (the selected sample for
            "best_of_n", otherwise the first sample) with these keys set:
                - output: Final ensembled image (np.ndarray, BGR, uint8)
                - outputs: List of all individual outputs
                - scores: Quality scores for each sample
                - selected_idx: Index of selected sample (-1 unless best_of_n)
                - strategy: Aggregation strategy used
                - n_samples: Number of ensemble members

        Raises:
            RuntimeError: If load() has not been called.
            ValueError: If the strategy is unknown or ``n_samples`` < 1.
        """
        if not self.is_loaded:
            raise RuntimeError("Pipeline not loaded. Call load() first.")

        # Fail fast: reject bad configuration before paying for N pipeline
        # runs instead of discovering it during aggregation.
        if self.strategy not in self._STRATEGIES:
            raise ValueError(f"Unknown strategy: {self.strategy}")
        if self.n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {self.n_samples}")

        base = self.base_seed if seed is None else seed

        # One pipeline run per ensemble member, seeds base, base+1, ...
        results = [
            self._pipeline.generate(
                image,
                procedure=procedure,
                intensity=intensity,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                strength=strength,
                seed=base + i,
                **kwargs,
            )
            for i in range(self.n_samples)
        ]
        outputs = [r["output"] for r in results]

        # -1 means "no single sample was selected" (blending strategies).
        selected_idx = -1
        uniform = [1.0 / self.n_samples] * self.n_samples

        if self.strategy == "pixel_average":
            final, scores = self._pixel_average(outputs), uniform
        elif self.strategy == "median":
            final, scores = self._pixel_median(outputs), uniform
        elif self.strategy == "weighted_average":
            final, scores = self._weighted_average(outputs, image)
        else:  # "best_of_n" (strategy was validated above)
            final, scores, selected_idx = self._best_of_n(outputs, image)

        # Carry over the remaining pipeline metadata from the selected sample,
        # or from the first sample when the output is a blend.
        best_idx = selected_idx if selected_idx >= 0 else 0
        ensemble_result = dict(results[best_idx])
        ensemble_result.update({
            "output": final,
            "outputs": outputs,
            "scores": scores,
            "selected_idx": selected_idx,
            "strategy": self.strategy,
            "n_samples": self.n_samples,
        })
        return ensemble_result

    def _pixel_average(self, outputs: list[np.ndarray]) -> np.ndarray:
        """Return the per-pixel mean of ``outputs`` as uint8 (truncating)."""
        stacked = np.stack(outputs, axis=0).astype(np.float32)
        return np.clip(stacked.mean(axis=0), 0, 255).astype(np.uint8)

    def _pixel_median(self, outputs: list[np.ndarray]) -> np.ndarray:
        """Return the per-pixel median of ``outputs`` as uint8 (truncating)."""
        stacked = np.stack(outputs, axis=0)
        return np.median(stacked, axis=0).astype(np.uint8)

    @staticmethod
    def _normalize_weights(scores: list[float]) -> list[float]:
        """Turn raw quality scores into non-negative weights summing to 1.

        Negative and non-finite scores contribute zero weight. If no score is
        positive the weights fall back to uniform, so the blend never
        collapses to an all-zero image.
        """
        if not scores:
            return []
        arr = np.asarray(scores, dtype=np.float64)
        arr = np.where(np.isfinite(arr) & (arr > 0), arr, 0.0)
        total = float(arr.sum())
        if total <= 0.0:
            return [1.0 / len(scores)] * len(scores)
        return (arr / total).tolist()

    @staticmethod
    def _select_best(scores: list[float]) -> int:
        """Return the index of the highest finite score (0 if none is finite).

        Plain ``np.argmax`` would pick a NaN entry as the maximum, so
        non-finite scores are demoted to -inf before the comparison.
        """
        arr = np.asarray(scores, dtype=np.float64)
        arr = np.where(np.isfinite(arr), arr, -np.inf)
        return int(np.argmax(arr))

    def _weighted_average(
        self,
        outputs: list[np.ndarray],
        reference: np.ndarray,
    ) -> tuple[np.ndarray, list[float]]:
        """Blend ``outputs`` with weights derived from SSIM to ``reference``.

        Returns:
            The blended uint8 image and the raw (un-normalized) SSIM scores
            as returned by ``compute_ssim``.
        """
        from landmarkdiff.evaluation import compute_ssim

        scores = [float(compute_ssim(out, reference)) for out in outputs]
        weights = self._normalize_weights(scores)

        blended = np.zeros(outputs[0].shape, dtype=np.float32)
        for out, weight in zip(outputs, weights):
            blended += out.astype(np.float32) * weight

        return np.clip(blended, 0, 255).astype(np.uint8), scores

    def _best_of_n(
        self,
        outputs: list[np.ndarray],
        reference: np.ndarray,
    ) -> tuple[np.ndarray, list[float], int]:
        """Select the output with the highest identity similarity.

        Returns:
            The selected image, the per-sample scores as returned by
            ``compute_identity_similarity``, and the selected index.
        """
        from landmarkdiff.evaluation import compute_identity_similarity

        scores = [
            float(compute_identity_similarity(out, reference))
            for out in outputs
        ]
        best_idx = self._select_best(scores)
        return outputs[best_idx], scores, best_idx
219
+
220
+
221
def ensemble_inference(
    image_path: str,
    procedure: str = "rhinoplasty",
    intensity: float = 65.0,
    output_dir: str = "ensemble_output",
    n_samples: int = 5,
    strategy: str = "best_of_n",
    mode: str = "tps",
    controlnet_checkpoint: str | None = None,
    displacement_model_path: str | None = None,
    seed: int = 42,
) -> None:
    """Run an ensemble on one image file and write the results to disk.

    Creates ``output_dir``, reads ``image_path`` and resizes it to 512x512,
    generates the ensemble, then writes the final image, the resized input,
    every individual sample and a side-by-side comparison grid. Per-sample
    scores are printed. If the image cannot be read, an error line is printed
    and the function returns without generating anything.
    """
    from pathlib import Path

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    source = cv2.imread(image_path)
    if source is None:
        print(f"ERROR: Cannot read image: {image_path}")
        return
    source = cv2.resize(source, (512, 512))

    runner = EnsembleInference(
        mode=mode,
        controlnet_checkpoint=controlnet_checkpoint,
        displacement_model_path=displacement_model_path,
        n_samples=n_samples,
        strategy=strategy,
        base_seed=seed,
    )
    runner.load()

    print(f"Generating ensemble ({n_samples} samples, strategy={strategy})...")
    result = runner.generate(
        source,
        procedure=procedure,
        intensity=intensity,
        seed=seed,
    )

    # Final image and the (resized) input it was generated from.
    final_path = target_dir / "ensemble_output.png"
    cv2.imwrite(str(final_path), result["output"])
    cv2.imwrite(str(target_dir / "original.png"), source)

    # Each ensemble member, with its score; the selected one is flagged.
    for idx, sample in enumerate(result["outputs"]):
        cv2.imwrite(str(target_dir / f"sample_{idx:02d}.png"), sample)
        sample_score = result["scores"][idx]
        if idx == result.get("selected_idx"):
            suffix = " <-- selected"
        else:
            suffix = ""
        print(f" Sample {idx}: score={sample_score:.4f}" + suffix)

    # Side-by-side strip: input, every sample, final; 256px tiles keep it compact.
    tiles = []
    for panel in [source] + result["outputs"] + [result["output"]]:
        tiles.append(cv2.resize(panel, (256, 256)))
    cv2.imwrite(str(target_dir / "comparison_grid.png"), np.hstack(tiles))

    print(f"\nEnsemble output saved: {final_path}")
    chosen = result.get("selected_idx", -1)
    if chosen >= 0:
        chosen_score = result["scores"][chosen]
        print(f"Selected sample: {chosen} (score={chosen_score:.4f})")
286
+
287
+
288
if __name__ == "__main__":
    import argparse

    # Command-line front end: parse the options and hand them to
    # ensemble_inference() by keyword.
    cli = argparse.ArgumentParser(description="Ensemble inference")
    cli.add_argument("image", help="Input face image")
    cli.add_argument("--procedure", default="rhinoplasty")
    cli.add_argument("--intensity", type=float, default=65.0)
    cli.add_argument("--output", default="ensemble_output")
    cli.add_argument("--n_samples", type=int, default=5)
    cli.add_argument(
        "--strategy",
        default="best_of_n",
        choices=["pixel_average", "weighted_average", "best_of_n", "median"],
    )
    cli.add_argument(
        "--mode",
        default="tps",
        choices=["controlnet", "img2img", "tps"],
    )
    cli.add_argument("--checkpoint", default=None)
    cli.add_argument("--displacement-model", default=None)
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    ensemble_inference(
        image_path=opts.image,
        procedure=opts.procedure,
        intensity=opts.intensity,
        output_dir=opts.output,
        n_samples=opts.n_samples,
        strategy=opts.strategy,
        mode=opts.mode,
        controlnet_checkpoint=opts.checkpoint,
        displacement_model_path=opts.displacement_model,
        seed=opts.seed,
    )