TaliDror commited on
Commit ·
1d0d403
1
Parent(s): 1cc799a
improved UI
Browse files
app.py
CHANGED
|
@@ -333,31 +333,34 @@ def _extract_facenet_logits(img: Image.Image, model) -> torch.Tensor:
|
|
| 333 |
return logits.squeeze(0)
|
| 334 |
|
| 335 |
|
| 336 |
-
def select_best_images(
|
|
|
|
| 337 |
global facenet_model
|
| 338 |
|
| 339 |
-
n = min(n, len(
|
|
|
|
| 340 |
if facenet_model is None:
|
| 341 |
-
return
|
| 342 |
|
| 343 |
embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img in images])
|
| 344 |
sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
|
| 345 |
avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(images) - 1)
|
| 346 |
top_indices = avg_sims.argsort(descending=True)[:n].tolist()
|
| 347 |
print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
|
| 348 |
-
return [
|
| 349 |
|
| 350 |
|
| 351 |
-
def select_best_images_combined(
|
|
|
|
| 352 |
global mtcnn_model, facenet_classify_model
|
| 353 |
|
| 354 |
-
n = min(n, len(
|
| 355 |
if mtcnn_model is None or facenet_classify_model is None:
|
| 356 |
print("[select_best:combined] models unavailable, falling back to pairwise")
|
| 357 |
-
return select_best_images(
|
| 358 |
|
| 359 |
scores = []
|
| 360 |
-
for idx, img in enumerate(
|
| 361 |
_, probs = mtcnn_model.detect(img)
|
| 362 |
det_conf = float(probs[0]) if probs is not None and probs[0] is not None else 0.0
|
| 363 |
|
|
@@ -372,7 +375,7 @@ def select_best_images_combined(images: list, n: int) -> list:
|
|
| 372 |
|
| 373 |
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
|
| 374 |
print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
|
| 375 |
-
return [
|
| 376 |
|
| 377 |
|
| 378 |
SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
|
|
@@ -421,7 +424,7 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
|
|
| 421 |
id_emb = speech_z.to(dtype)
|
| 422 |
id_emb_projected = project_face_embs(pipeline, id_emb)
|
| 423 |
|
| 424 |
-
|
| 425 |
for seed in GENERATION_SEEDS:
|
| 426 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 427 |
|
|
@@ -433,11 +436,14 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
|
|
| 433 |
generator=generator,
|
| 434 |
).images[0]
|
| 435 |
|
| 436 |
-
|
| 437 |
|
| 438 |
if selection_method == "Detection + Classify confidence":
|
| 439 |
-
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
# ---------------------------------------------------------------------------
|
| 443 |
# Model loading
|
|
|
|
| 333 |
return logits.squeeze(0)
|
| 334 |
|
| 335 |
|
| 336 |
+
def select_best_images(pairs: list, n: int) -> list:
|
| 337 |
+
"""pairs: list of (image, seed). Returns top-n (image, seed) pairs."""
|
| 338 |
global facenet_model
|
| 339 |
|
| 340 |
+
n = min(n, len(pairs))
|
| 341 |
+
images = [p[0] for p in pairs]
|
| 342 |
if facenet_model is None:
|
| 343 |
+
return pairs[:n]
|
| 344 |
|
| 345 |
embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img in images])
|
| 346 |
sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
|
| 347 |
avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(images) - 1)
|
| 348 |
top_indices = avg_sims.argsort(descending=True)[:n].tolist()
|
| 349 |
print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
|
| 350 |
+
return [pairs[i] for i in top_indices]
|
| 351 |
|
| 352 |
|
| 353 |
+
def select_best_images_combined(pairs: list, n: int) -> list:
|
| 354 |
+
"""pairs: list of (image, seed). Returns top-n (image, seed) pairs."""
|
| 355 |
global mtcnn_model, facenet_classify_model
|
| 356 |
|
| 357 |
+
n = min(n, len(pairs))
|
| 358 |
if mtcnn_model is None or facenet_classify_model is None:
|
| 359 |
print("[select_best:combined] models unavailable, falling back to pairwise")
|
| 360 |
+
return select_best_images(pairs, n)
|
| 361 |
|
| 362 |
scores = []
|
| 363 |
+
for idx, (img, _) in enumerate(pairs):
|
| 364 |
_, probs = mtcnn_model.detect(img)
|
| 365 |
det_conf = float(probs[0]) if probs is not None and probs[0] is not None else 0.0
|
| 366 |
|
|
|
|
| 375 |
|
| 376 |
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
|
| 377 |
print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
|
| 378 |
+
return [pairs[i] for i in top_indices]
|
| 379 |
|
| 380 |
|
| 381 |
SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
|
|
|
|
| 424 |
id_emb = speech_z.to(dtype)
|
| 425 |
id_emb_projected = project_face_embs(pipeline, id_emb)
|
| 426 |
|
| 427 |
+
pairs = []
|
| 428 |
for seed in GENERATION_SEEDS:
|
| 429 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 430 |
|
|
|
|
| 436 |
generator=generator,
|
| 437 |
).images[0]
|
| 438 |
|
| 439 |
+
pairs.append((img, seed))
|
| 440 |
|
| 441 |
if selection_method == "Detection + Classify confidence":
|
| 442 |
+
best = select_best_images_combined(pairs, int(num_display))
|
| 443 |
+
else:
|
| 444 |
+
best = select_best_images(pairs, int(num_display))
|
| 445 |
+
|
| 446 |
+
return [(img, f"Seed: {seed}") for img, seed in best], ""
|
| 447 |
|
| 448 |
# ---------------------------------------------------------------------------
|
| 449 |
# Model loading
|