TaliDror committed on
Commit
deb433b
·
1 Parent(s): 626735d

adaptation to enable ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -25,6 +25,7 @@ from PIL import Image
25
  from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
26
  from huggingface_hub import snapshot_download, hf_hub_download
27
  import gradio as gr
 
28
 
29
  from external.arc2face import CLIPTextModelWrapper, project_face_embs
30
  from core.models.encoder.speech_face_encoder import SpeechFaceXVectorEncoder
@@ -357,7 +358,7 @@ def select_best_image(images: list, method: str) -> Image.Image:
357
  # ---------------------------------------------------------------------------
358
  # Generation
359
  # ---------------------------------------------------------------------------
360
-
361
  def generate(audio_path, num_samples, guidance_scale, num_inference_steps, base_seed, select_best, best_selection="pairwise"):
362
  global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
363
 
@@ -373,7 +374,8 @@ def generate(audio_path, num_samples, guidance_scale, num_inference_steps, base_
373
 
374
  with torch.no_grad():
375
  speech_z = speaker_encoder(waveform, normalize=True, apply_shared_projection=False)
376
- id_emb = speech_z.to(torch.float16)
 
377
  id_emb_projected = project_face_embs(pipeline, id_emb)
378
 
379
  images = []
@@ -406,6 +408,7 @@ def generate(audio_path, num_samples, guidance_scale, num_inference_steps, base_
406
 
407
  def load_models():
408
  global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
 
409
 
410
  device = "cuda" if torch.cuda.is_available() else "cpu"
411
  print(f"Using device: {device}")
@@ -432,21 +435,21 @@ def load_models():
432
  # Diffusion pipeline
433
  print("Loading diffusion pipeline...")
434
  if SKIP_LORA:
435
- encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder', torch_dtype=torch.float16)
436
- unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face', torch_dtype=torch.float16)
437
  print(" Using base Arc2Face (no LoRA)")
438
  else:
439
  checkpoint_dir = snapshot_download(CHECKPOINT_REPO)
440
  checkpoint = resolve_checkpoint_path(checkpoint_dir)
441
  print(f" Checkpoint: {checkpoint}")
442
- encoder = load_encoder_with_lora(checkpoint).to(dtype=torch.float16)
443
- unet = load_unet_with_lora(checkpoint).to(dtype=torch.float16)
444
 
445
  pipeline = StableDiffusionPipeline.from_pretrained(
446
  BASE_MODEL,
447
  text_encoder=encoder,
448
  unet=unet,
449
- torch_dtype=torch.float16,
450
  safety_checker=None,
451
  )
452
  pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -471,7 +474,8 @@ def load_models():
471
  # ---------------------------------------------------------------------------
472
 
473
  def build_demo():
474
- facenet_available = facenet_model is not None and facenet_classify_model is not None
 
475
 
476
  with gr.Blocks(title="Speech-to-Face Generation") as demo:
477
  gr.Markdown("# Speech-to-Face Generation")
@@ -526,7 +530,6 @@ def build_demo():
526
  # Entry point
527
  # ---------------------------------------------------------------------------
528
 
529
- load_models()
530
-
531
  demo = build_demo()
 
532
  demo.launch()
 
25
  from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
26
  from huggingface_hub import snapshot_download, hf_hub_download
27
  import gradio as gr
28
+ import spaces
29
 
30
  from external.arc2face import CLIPTextModelWrapper, project_face_embs
31
  from core.models.encoder.speech_face_encoder import SpeechFaceXVectorEncoder
 
358
  # ---------------------------------------------------------------------------
359
  # Generation
360
  # ---------------------------------------------------------------------------
361
+ @spaces.GPU(duration=120)
362
  def generate(audio_path, num_samples, guidance_scale, num_inference_steps, base_seed, select_best, best_selection="pairwise"):
363
  global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
364
 
 
374
 
375
  with torch.no_grad():
376
  speech_z = speaker_encoder(waveform, normalize=True, apply_shared_projection=False)
377
+ dtype = torch.float16 if device == "cuda" else torch.float32
378
+ id_emb = speech_z.to(dtype)
379
  id_emb_projected = project_face_embs(pipeline, id_emb)
380
 
381
  images = []
 
408
 
409
  def load_models():
410
  global pipeline, speaker_encoder, facenet_model, facenet_classify_model, device
411
+ dtype = torch.float16 if device == "cuda" else torch.float32
412
 
413
  device = "cuda" if torch.cuda.is_available() else "cpu"
414
  print(f"Using device: {device}")
 
435
  # Diffusion pipeline
436
  print("Loading diffusion pipeline...")
437
  if SKIP_LORA:
438
+ encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder', torch_dtype=dtype)
439
+ unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face', torch_dtype=dtype)
440
  print(" Using base Arc2Face (no LoRA)")
441
  else:
442
  checkpoint_dir = snapshot_download(CHECKPOINT_REPO)
443
  checkpoint = resolve_checkpoint_path(checkpoint_dir)
444
  print(f" Checkpoint: {checkpoint}")
445
+ encoder = load_encoder_with_lora(checkpoint).to(dtype=dtype)
446
+ unet = load_unet_with_lora(checkpoint).to(dtype=dtype)
447
 
448
  pipeline = StableDiffusionPipeline.from_pretrained(
449
  BASE_MODEL,
450
  text_encoder=encoder,
451
  unet=unet,
452
+ torch_dtype=dtype,
453
  safety_checker=None,
454
  )
455
  pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
 
474
  # ---------------------------------------------------------------------------
475
 
476
  def build_demo():
477
+ #facenet_available = facenet_model is not None and facenet_classify_model is not None
478
+ facenet_available = True
479
 
480
  with gr.Blocks(title="Speech-to-Face Generation") as demo:
481
  gr.Markdown("# Speech-to-Face Generation")
 
530
  # Entry point
531
  # ---------------------------------------------------------------------------
532
 
 
 
533
  demo = build_demo()
534
+ demo.queue()
535
  demo.launch()