TaliDror committed on
Commit ·
1b85d17
1
Parent(s): b4f2d1d
added selection method
Browse files
app.py
CHANGED
|
@@ -37,6 +37,7 @@ pipeline = None
|
|
| 37 |
speaker_encoder = None
|
| 38 |
facenet_model = None
|
| 39 |
facenet_classify_model = None
|
|
|
|
| 40 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 41 |
|
| 42 |
|
|
@@ -347,13 +348,43 @@ def select_best_images(images: list, n: int) -> list:
|
|
| 347 |
return [images[i] for i in top_indices]
|
| 348 |
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
# ---------------------------------------------------------------------------
|
| 351 |
# Generation
|
| 352 |
# ---------------------------------------------------------------------------
|
| 353 |
INTERNAL_SAMPLES = 10
|
| 354 |
|
| 355 |
@spaces.GPU(duration=120)
|
| 356 |
-
def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed):
|
| 357 |
global pipeline, speaker_encoder, facenet_model, device
|
| 358 |
|
| 359 |
if audio_path is None:
|
|
@@ -402,6 +433,8 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
|
|
| 402 |
|
| 403 |
images.append(img)
|
| 404 |
|
|
|
|
|
|
|
| 405 |
return select_best_images(images, int(num_display)), ""
|
| 406 |
|
| 407 |
# ---------------------------------------------------------------------------
|
|
@@ -409,7 +442,7 @@ def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_
|
|
| 409 |
# ---------------------------------------------------------------------------
|
| 410 |
|
| 411 |
def load_models():
|
| 412 |
-
global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
|
| 413 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 414 |
|
| 415 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -458,17 +491,19 @@ def load_models():
|
|
| 458 |
pipeline = pipeline.to(device)
|
| 459 |
print(" Pipeline ready")
|
| 460 |
|
| 461 |
-
# FaceNet for best-sample selection
|
| 462 |
-
print("Loading FaceNet for best-sample selection...")
|
| 463 |
try:
|
| 464 |
-
from facenet_pytorch import InceptionResnetV1
|
| 465 |
facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
|
| 466 |
facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
|
| 467 |
-
|
|
|
|
| 468 |
except Exception as e:
|
| 469 |
-
print(f" FaceNet unavailable ({e}); select-best will fall back to first image")
|
| 470 |
facenet_model = None
|
| 471 |
facenet_classify_model = None
|
|
|
|
| 472 |
|
| 473 |
|
| 474 |
# ---------------------------------------------------------------------------
|
|
@@ -491,20 +526,25 @@ def build_demo():
|
|
| 491 |
guidance_scale = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="Guidance scale")
|
| 492 |
num_steps = gr.Slider(10, 50, value=25, step=5, label="Inference steps")
|
| 493 |
base_seed = gr.Slider(0, 9999, value=42, step=1, label="Base seed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 495 |
|
| 496 |
with gr.Column():
|
| 497 |
gallery = gr.Gallery(label="Generated Images")
|
| 498 |
status = gr.Markdown(visible=False)
|
| 499 |
|
| 500 |
-
def _generate(audio, n, gs, steps, seed):
|
| 501 |
-
imgs, msg = generate(audio, n, gs, steps, seed)
|
| 502 |
visible = bool(msg)
|
| 503 |
return imgs, gr.update(value=msg, visible=visible)
|
| 504 |
|
| 505 |
generate_btn.click(
|
| 506 |
fn=_generate,
|
| 507 |
-
inputs=[audio_input, num_display, guidance_scale, num_steps, base_seed],
|
| 508 |
outputs=[gallery, status],
|
| 509 |
)
|
| 510 |
|
|
|
|
| 37 |
speaker_encoder = None
|
| 38 |
facenet_model = None
|
| 39 |
facenet_classify_model = None
|
| 40 |
+
mtcnn_model = None
|
| 41 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 42 |
|
| 43 |
|
|
|
|
| 348 |
return [images[i] for i in top_indices]
|
| 349 |
|
| 350 |
|
| 351 |
+
def select_best_images_combined(images: list, n: int) -> list:
    """Return the ``n`` images with the highest detection x classification score.

    Each image is scored as ``det_conf * classify_conf``:

    * ``det_conf`` is the first probability reported by
      ``mtcnn_model.detect`` (``0.0`` when no probability is reported).
    * ``classify_conf`` is the largest softmax probability produced by
      ``facenet_classify_model`` for the transformed image.

    Args:
        images: Candidate images; each must support ``.convert("RGB")``
            (presumably PIL images -- confirm against the caller).
        n: Maximum number of images to return; clamped to
            ``[0, len(images)]``.

    Returns:
        Up to ``n`` images ordered from highest to lowest combined score.
        When either model is unavailable, the result of
        ``select_best_images(images, n)`` is returned instead.
    """
    global mtcnn_model, facenet_classify_model

    # Clamp on both sides: a negative n would otherwise make the final slice
    # drop items from the tail of the ranking instead of returning nothing.
    n = max(0, min(n, len(images)))
    if mtcnn_model is None or facenet_classify_model is None:
        print("[select_best:combined] models unavailable, falling back to pairwise")
        return select_best_images(images, n)

    # Built once instead of per image; assumes _facenet_transform() returns a
    # stateless preprocessing callable -- TODO confirm against its definition.
    transform = _facenet_transform()

    scores = []
    for idx, img in enumerate(images):
        # Convert once so the detector and the classifier score the same
        # 3-channel image regardless of the input mode.
        rgb = img.convert("RGB")

        _, probs = mtcnn_model.detect(rgb)
        has_prob = probs is not None and len(probs) > 0 and probs[0] is not None
        det_conf = float(probs[0]) if has_prob else 0.0

        tensor = transform(rgb).unsqueeze(0)
        with torch.no_grad():
            logits = facenet_classify_model(tensor)
        classify_conf = float(F.softmax(logits, dim=1).max(dim=1).values[0])

        combined = det_conf * classify_conf
        scores.append(combined)
        print(f"  [combined] idx={idx} det={det_conf:.3f} classify={classify_conf:.3f} combined={combined:.3f}")

    # sorted() is stable, so equal scores keep their original relative order.
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
    print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
    return [images[i] for i in top_indices]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# Labels offered by the "Best-image selection method" radio in the UI.
# generate() compares the chosen value against the second label verbatim and
# uses the first label as its default, so keep these strings in sync with it.
SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
|
| 379 |
+
|
| 380 |
+
|
| 381 |
# ---------------------------------------------------------------------------
|
| 382 |
# Generation
|
| 383 |
# ---------------------------------------------------------------------------
|
| 384 |
INTERNAL_SAMPLES = 10
|
| 385 |
|
| 386 |
@spaces.GPU(duration=120)
|
| 387 |
+
def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed, selection_method="Pairwise similarity"):
|
| 388 |
global pipeline, speaker_encoder, facenet_model, device
|
| 389 |
|
| 390 |
if audio_path is None:
|
|
|
|
| 433 |
|
| 434 |
images.append(img)
|
| 435 |
|
| 436 |
+
if selection_method == "Detection + Classify confidence":
|
| 437 |
+
return select_best_images_combined(images, int(num_display)), ""
|
| 438 |
return select_best_images(images, int(num_display)), ""
|
| 439 |
|
| 440 |
# ---------------------------------------------------------------------------
|
|
|
|
| 442 |
# ---------------------------------------------------------------------------
|
| 443 |
|
| 444 |
def load_models():
|
| 445 |
+
global pipeline, speaker_encoder, facenet_model, facenet_classify_model, mtcnn_model, device
|
| 446 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 447 |
|
| 448 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 491 |
pipeline = pipeline.to(device)
|
| 492 |
print(" Pipeline ready")
|
| 493 |
|
| 494 |
+
# FaceNet + MTCNN for best-sample selection
|
| 495 |
+
print("Loading FaceNet + MTCNN for best-sample selection...")
|
| 496 |
try:
|
| 497 |
+
from facenet_pytorch import InceptionResnetV1, MTCNN
|
| 498 |
facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
|
| 499 |
facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
|
| 500 |
+
mtcnn_model = MTCNN(keep_all=False, device='cpu')
|
| 501 |
+
print(" FaceNet + MTCNN ready")
|
| 502 |
except Exception as e:
|
| 503 |
+
print(f" FaceNet/MTCNN unavailable ({e}); select-best will fall back to first image")
|
| 504 |
facenet_model = None
|
| 505 |
facenet_classify_model = None
|
| 506 |
+
mtcnn_model = None
|
| 507 |
|
| 508 |
|
| 509 |
# ---------------------------------------------------------------------------
|
|
|
|
| 526 |
guidance_scale = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="Guidance scale")
|
| 527 |
num_steps = gr.Slider(10, 50, value=25, step=5, label="Inference steps")
|
| 528 |
base_seed = gr.Slider(0, 9999, value=42, step=1, label="Base seed")
|
| 529 |
+
selection_method = gr.Radio(
|
| 530 |
+
choices=SELECTION_METHODS,
|
| 531 |
+
value=SELECTION_METHODS[1],
|
| 532 |
+
label="Best-image selection method",
|
| 533 |
+
)
|
| 534 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 535 |
|
| 536 |
with gr.Column():
|
| 537 |
gallery = gr.Gallery(label="Generated Images")
|
| 538 |
status = gr.Markdown(visible=False)
|
| 539 |
|
| 540 |
+
def _generate(audio, n, gs, steps, seed, sel_method):
|
| 541 |
+
imgs, msg = generate(audio, n, gs, steps, seed, sel_method)
|
| 542 |
visible = bool(msg)
|
| 543 |
return imgs, gr.update(value=msg, visible=visible)
|
| 544 |
|
| 545 |
generate_btn.click(
|
| 546 |
fn=_generate,
|
| 547 |
+
inputs=[audio_input, num_display, guidance_scale, num_steps, base_seed, selection_method],
|
| 548 |
outputs=[gallery, status],
|
| 549 |
)
|
| 550 |
|