TaliDror committed on
Commit
1b85d17
·
1 Parent(s): b4f2d1d

added selection method

Browse files
Files changed (1) hide show
  1. app.py +50 -10
app.py CHANGED
@@ -37,6 +37,7 @@ pipeline = None
37
  speaker_encoder = None
38
  facenet_model = None
39
  facenet_classify_model = None
 
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
42
 
@@ -347,13 +348,43 @@ def select_best_images(images: list, n: int) -> list:
347
  return [images[i] for i in top_indices]
348
 
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  # ---------------------------------------------------------------------------
351
  # Generation
352
  # ---------------------------------------------------------------------------
353
  INTERNAL_SAMPLES = 10
354
 
355
  @spaces.GPU(duration=120)
356
- def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed):
357
  global pipeline, speaker_encoder, facenet_model, device
358
 
359
  if audio_path is None:
@@ -402,6 +433,8 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
402
 
403
  images.append(img)
404
 
 
 
405
  return select_best_images(images, int(num_display)), ""
406
 
407
  # ---------------------------------------------------------------------------
@@ -409,7 +442,7 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
409
  # ---------------------------------------------------------------------------
410
 
411
  def load_models():
412
- global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
413
  dtype = torch.float16 if device == "cuda" else torch.float32
414
 
415
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -458,17 +491,19 @@ def load_models():
458
  pipeline = pipeline.to(device)
459
  print(" Pipeline ready")
460
 
461
- # FaceNet for best-sample selection
462
- print("Loading FaceNet for best-sample selection...")
463
  try:
464
- from facenet_pytorch import InceptionResnetV1
465
  facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
466
  facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
467
- print(" FaceNet ready")
 
468
  except Exception as e:
469
- print(f" FaceNet unavailable ({e}); select-best will fall back to first image")
470
  facenet_model = None
471
  facenet_classify_model = None
 
472
 
473
 
474
  # ---------------------------------------------------------------------------
@@ -491,20 +526,25 @@ def build_demo():
491
  guidance_scale = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="Guidance scale")
492
  num_steps = gr.Slider(10, 50, value=25, step=5, label="Inference steps")
493
  base_seed = gr.Slider(0, 9999, value=42, step=1, label="Base seed")
 
 
 
 
 
494
  generate_btn = gr.Button("Generate", variant="primary")
495
 
496
  with gr.Column():
497
  gallery = gr.Gallery(label="Generated Images")
498
  status = gr.Markdown(visible=False)
499
 
500
- def _generate(audio, n, gs, steps, seed):
501
- imgs, msg = generate(audio, n, gs, steps, seed)
502
  visible = bool(msg)
503
  return imgs, gr.update(value=msg, visible=visible)
504
 
505
  generate_btn.click(
506
  fn=_generate,
507
- inputs=[audio_input, num_display, guidance_scale, num_steps, base_seed],
508
  outputs=[gallery, status],
509
  )
510
 
 
37
  speaker_encoder = None
38
  facenet_model = None
39
  facenet_classify_model = None
40
+ mtcnn_model = None
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
 
 
348
  return [images[i] for i in top_indices]
349
 
350
 
351
def select_best_images_combined(images: list, n: int) -> list:
    """Return the ``n`` images with the highest detection x classification score.

    Each image is scored as the product of two confidences:

    * the probability ``mtcnn_model.detect`` reports for the first face it
      returns (0.0 when it reports none), and
    * the largest softmax probability produced by ``facenet_classify_model``.

    Args:
        images: Image objects exposing ``convert("RGB")`` (presumably PIL
            images -- confirm against the caller).
        n: Number of images to keep; clamped to ``[0, len(images)]``.

    Returns:
        Up to ``n`` images ordered from highest to lowest combined score.
        When either model is not loaded, the result of
        ``select_best_images(images, n)`` is returned instead.
    """
    # Clamp on both sides: a negative n would otherwise turn the final
    # ``[:n]`` slice into "everything except the last |n| images".
    n = max(0, min(n, len(images)))
    if mtcnn_model is None or facenet_classify_model is None:
        print("[select_best:combined] models unavailable, falling back to pairwise")
        return select_best_images(images, n)
    if n == 0:
        # Nothing to pick; skip running both models over every image.
        return []

    # Loop-invariant, so build it once.
    # NOTE(review): assumes _facenet_transform() returns a stateless,
    # reusable transform -- confirm against its definition.
    transform = _facenet_transform()

    scores = []
    for idx, img in enumerate(images):
        # Feed the same RGB image to both models so the two confidences
        # describe identical pixel data regardless of the input mode.
        rgb = img.convert("RGB")

        _, probs = mtcnn_model.detect(rgb)
        has_face = probs is not None and len(probs) > 0 and probs[0] is not None
        det_conf = float(probs[0]) if has_face else 0.0

        tensor = transform(rgb).unsqueeze(0)
        with torch.no_grad():
            logits = facenet_classify_model(tensor)
        classify_conf = float(F.softmax(logits, dim=1).max(dim=1).values[0])

        combined = det_conf * classify_conf
        scores.append(combined)
        print(f" [combined] idx={idx} det={det_conf:.3f} classify={classify_conf:.3f} combined={combined:.3f}")

    # sorted() is stable, so ties keep their original (generation) order.
    top_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:n]
    print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
    return [images[i] for i in top_indices]
376
+
377
+
378
+ SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
379
+
380
+
381
  # ---------------------------------------------------------------------------
382
  # Generation
383
  # ---------------------------------------------------------------------------
384
  INTERNAL_SAMPLES = 10
385
 
386
  @spaces.GPU(duration=120)
387
+ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed, selection_method="Pairwise similarity"):
388
  global pipeline, speaker_encoder, facenet_model, device
389
 
390
  if audio_path is None:
 
433
 
434
  images.append(img)
435
 
436
+ if selection_method == "Detection + Classify confidence":
437
+ return select_best_images_combined(images, int(num_display)), ""
438
  return select_best_images(images, int(num_display)), ""
439
 
440
  # ---------------------------------------------------------------------------
 
442
  # ---------------------------------------------------------------------------
443
 
444
  def load_models():
445
+ global pipeline, speaker_encoder, facenet_model, facenet_classify_model, mtcnn_model, device
446
  dtype = torch.float16 if device == "cuda" else torch.float32
447
 
448
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
491
  pipeline = pipeline.to(device)
492
  print(" Pipeline ready")
493
 
494
+ # FaceNet + MTCNN for best-sample selection
495
+ print("Loading FaceNet + MTCNN for best-sample selection...")
496
  try:
497
+ from facenet_pytorch import InceptionResnetV1, MTCNN
498
  facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
499
  facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
500
+ mtcnn_model = MTCNN(keep_all=False, device='cpu')
501
+ print(" FaceNet + MTCNN ready")
502
  except Exception as e:
503
+ print(f" FaceNet/MTCNN unavailable ({e}); select-best will fall back to first image")
504
  facenet_model = None
505
  facenet_classify_model = None
506
+ mtcnn_model = None
507
 
508
 
509
  # ---------------------------------------------------------------------------
 
526
  guidance_scale = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="Guidance scale")
527
  num_steps = gr.Slider(10, 50, value=25, step=5, label="Inference steps")
528
  base_seed = gr.Slider(0, 9999, value=42, step=1, label="Base seed")
529
+ selection_method = gr.Radio(
530
+ choices=SELECTION_METHODS,
531
+ value=SELECTION_METHODS[1],
532
+ label="Best-image selection method",
533
+ )
534
  generate_btn = gr.Button("Generate", variant="primary")
535
 
536
  with gr.Column():
537
  gallery = gr.Gallery(label="Generated Images")
538
  status = gr.Markdown(visible=False)
539
 
540
def _generate(audio, n, gs, steps, seed, sel_method):
    """Run ``generate`` and adapt its result for the gallery/status outputs.

    Returns the generated images together with a ``gr.update`` that sets the
    status text and shows it only when the message is non-empty.
    """
    images, message = generate(audio, n, gs, steps, seed, sel_method)
    status_update = gr.update(value=message, visible=bool(message))
    return images, status_update
544
 
545
  generate_btn.click(
546
  fn=_generate,
547
+ inputs=[audio_input, num_display, guidance_scale, num_steps, base_seed, selection_method],
548
  outputs=[gallery, status],
549
  )
550