TaliDror commited on
Commit
1d0d403
·
1 Parent(s): 1cc799a

improved UI

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -333,31 +333,34 @@ def _extract_facenet_logits(img: Image.Image, model) -> torch.Tensor:
333
  return logits.squeeze(0)
334
 
335
 
336
- def select_best_images(images: list, n: int) -> list:
 
337
  global facenet_model
338
 
339
- n = min(n, len(images))
 
340
  if facenet_model is None:
341
- return images[:n]
342
 
343
  embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img in images])
344
  sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
345
  avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(images) - 1)
346
  top_indices = avg_sims.argsort(descending=True)[:n].tolist()
347
  print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
348
- return [images[i] for i in top_indices]
349
 
350
 
351
- def select_best_images_combined(images: list, n: int) -> list:
 
352
  global mtcnn_model, facenet_classify_model
353
 
354
- n = min(n, len(images))
355
  if mtcnn_model is None or facenet_classify_model is None:
356
  print("[select_best:combined] models unavailable, falling back to pairwise")
357
- return select_best_images(images, n)
358
 
359
  scores = []
360
- for idx, img in enumerate(images):
361
  _, probs = mtcnn_model.detect(img)
362
  det_conf = float(probs[0]) if probs is not None and probs[0] is not None else 0.0
363
 
@@ -372,7 +375,7 @@ def select_best_images_combined(images: list, n: int) -> list:
372
 
373
  top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
374
  print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
375
- return [images[i] for i in top_indices]
376
 
377
 
378
  SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
@@ -421,7 +424,7 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
421
  id_emb = speech_z.to(dtype)
422
  id_emb_projected = project_face_embs(pipeline, id_emb)
423
 
424
- images = []
425
  for seed in GENERATION_SEEDS:
426
  generator = torch.Generator(device=device).manual_seed(seed)
427
 
@@ -433,11 +436,14 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
433
  generator=generator,
434
  ).images[0]
435
 
436
- images.append(img)
437
 
438
  if selection_method == "Detection + Classify confidence":
439
- return select_best_images_combined(images, int(num_display)), ""
440
- return select_best_images(images, int(num_display)), ""
 
 
 
441
 
442
  # ---------------------------------------------------------------------------
443
  # Model loading
 
333
  return logits.squeeze(0)
334
 
335
 
336
+ def select_best_images(pairs: list, n: int) -> list:
337
+ """pairs: list of (image, seed). Returns top-n (image, seed) pairs."""
338
  global facenet_model
339
 
340
+ n = min(n, len(pairs))
341
+ images = [p[0] for p in pairs]
342
  if facenet_model is None:
343
+ return pairs[:n]
344
 
345
  embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img in images])
346
  sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
347
  avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(images) - 1)
348
  top_indices = avg_sims.argsort(descending=True)[:n].tolist()
349
  print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
350
+ return [pairs[i] for i in top_indices]
351
 
352
 
353
+ def select_best_images_combined(pairs: list, n: int) -> list:
354
+ """pairs: list of (image, seed). Returns top-n (image, seed) pairs."""
355
  global mtcnn_model, facenet_classify_model
356
 
357
+ n = min(n, len(pairs))
358
  if mtcnn_model is None or facenet_classify_model is None:
359
  print("[select_best:combined] models unavailable, falling back to pairwise")
360
+ return select_best_images(pairs, n)
361
 
362
  scores = []
363
+ for idx, (img, _) in enumerate(pairs):
364
  _, probs = mtcnn_model.detect(img)
365
  det_conf = float(probs[0]) if probs is not None and probs[0] is not None else 0.0
366
 
 
375
 
376
  top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
377
  print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
378
+ return [pairs[i] for i in top_indices]
379
 
380
 
381
  SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
 
424
  id_emb = speech_z.to(dtype)
425
  id_emb_projected = project_face_embs(pipeline, id_emb)
426
 
427
+ pairs = []
428
  for seed in GENERATION_SEEDS:
429
  generator = torch.Generator(device=device).manual_seed(seed)
430
 
 
436
  generator=generator,
437
  ).images[0]
438
 
439
+ pairs.append((img, seed))
440
 
441
  if selection_method == "Detection + Classify confidence":
442
+ best = select_best_images_combined(pairs, int(num_display))
443
+ else:
444
+ best = select_best_images(pairs, int(num_display))
445
+
446
+ return [(img, f"Seed: {seed}") for img, seed in best], ""
447
 
448
  # ---------------------------------------------------------------------------
449
  # Model loading