TaliDror committed on
Commit
39db2c4
·
1 Parent(s): e66529d

Adapt generate() to enable ZeroGPU: detect the device at call time and load models lazily

Browse files
Files changed (1) hide show
  1. app.py +40 -18
app.py CHANGED
@@ -362,46 +362,68 @@ def select_best_image(images: list, method: str) -> Image.Image:
362
  def generate(audio_path, num_samples, guidance_scale, num_inference_steps, base_seed, select_best, best_selection="pairwise"):
363
  global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
364
 
365
- if pipeline is None:
366
- return None, "Model not loaded. Check Space configuration."
367
  if audio_path is None:
368
  return None, "Please provide an audio file."
369
 
 
 
 
 
 
 
 
 
 
 
 
370
  try:
371
  waveform = load_and_process_audio(audio_path, device, max_seconds=5.0)
372
  except Exception as e:
373
  return None, f"Audio loading failed: {e}"
374
 
 
 
375
  with torch.no_grad():
376
- speech_z = speaker_encoder(waveform, normalize=True, apply_shared_projection=False)
377
- dtype = torch.float16 if device == "cuda" else torch.float32
 
 
 
 
378
  id_emb = speech_z.to(dtype)
379
  id_emb_projected = project_face_embs(pipeline, id_emb)
380
 
381
- images = []
382
- for i in range(int(num_samples)):
383
- seed = int(base_seed) + i
384
- generator = torch.Generator(device=device).manual_seed(seed)
385
- img = pipeline(
386
- prompt_embeds=id_emb_projected,
387
- num_inference_steps=int(num_inference_steps),
388
- guidance_scale=float(guidance_scale),
389
- num_images_per_prompt=1,
390
- generator=generator,
391
- ).images[0]
392
- images.append(img)
 
 
393
 
394
  if select_best:
395
- model_ready = facenet_model is not None if best_selection in ("mean", "pairwise") else facenet_classify_model is not None
 
 
 
 
 
396
  if model_ready:
397
  best = select_best_image(images, best_selection)
398
  else:
399
  best = images[0]
 
400
  return [best], ""
401
 
402
  return images, ""
403
 
404
-
405
  # ---------------------------------------------------------------------------
406
  # Model loading
407
  # ---------------------------------------------------------------------------
 
362
  def generate(audio_path, num_samples, guidance_scale, num_inference_steps, base_seed, select_best, best_selection="pairwise"):
363
  global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
364
 
 
 
365
  if audio_path is None:
366
  return None, "Please provide an audio file."
367
 
368
+ device = "cuda" if torch.cuda.is_available() else "cpu"
369
+ print(f"[generate] device = {device}")
370
+
371
+ if pipeline is None or speaker_encoder is None:
372
+ print("[generate] Loading models lazily...")
373
+ load_models()
374
+ print("[generate] Models loaded.")
375
+
376
+ if pipeline is None or speaker_encoder is None:
377
+ return None, "Model loading failed. Check logs."
378
+
379
  try:
380
  waveform = load_and_process_audio(audio_path, device, max_seconds=5.0)
381
  except Exception as e:
382
  return None, f"Audio loading failed: {e}"
383
 
384
+ dtype = torch.float16 if device == "cuda" else torch.float32
385
+
386
  with torch.no_grad():
387
+ speech_z = speaker_encoder(
388
+ waveform,
389
+ normalize=True,
390
+ apply_shared_projection=False,
391
+ )
392
+
393
  id_emb = speech_z.to(dtype)
394
  id_emb_projected = project_face_embs(pipeline, id_emb)
395
 
396
+ images = []
397
+ for i in range(int(num_samples)):
398
+ seed = int(base_seed) + i
399
+ generator = torch.Generator(device=device).manual_seed(seed)
400
+
401
+ img = pipeline(
402
+ prompt_embeds=id_emb_projected,
403
+ num_inference_steps=int(num_inference_steps),
404
+ guidance_scale=float(guidance_scale),
405
+ num_images_per_prompt=1,
406
+ generator=generator,
407
+ ).images[0]
408
+
409
+ images.append(img)
410
 
411
  if select_best:
412
+ model_ready = (
413
+ facenet_model is not None
414
+ if best_selection in ("mean", "pairwise")
415
+ else facenet_classify_model is not None
416
+ )
417
+
418
  if model_ready:
419
  best = select_best_image(images, best_selection)
420
  else:
421
  best = images[0]
422
+
423
  return [best], ""
424
 
425
  return images, ""
426
 
 
427
  # ---------------------------------------------------------------------------
428
  # Model loading
429
  # ---------------------------------------------------------------------------