import os
# NOTE(review): `spaces` is imported ahead of torch; presumably intentional for
# the `spaces.GPU` decorator used further down — confirm before reordering.
import spaces
import sys

# Thread/BLAS settings are written to the environment here, ahead of the
# numpy/torch imports below.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["MKL_THREADING_LAYER"] = "GNU"

# ---------------------------------------------------------------------------
# Configuration — set CHECKPOINT_REPO as a HuggingFace Space secret to load
# fine-tuned models. If left empty, the demo uses base Arc2Face with a raw
# WavLM x-vector encoder (useful for testing that the Space works).
# ---------------------------------------------------------------------------
CHECKPOINT_REPO = os.environ.get("CHECKPOINT_REPO", "")
ENCODER_FILENAME = os.environ.get("ENCODER_FILENAME", "speaker_encoder.pt")
ARC2FACE_REPO = "FoivosPar/Arc2Face"
BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Both switches are driven by the same setting: an empty CHECKPOINT_REPO
# disables the LoRA weights and the fine-tuned speaker encoder together.
SKIP_LORA = not bool(CHECKPOINT_REPO)
SKIP_SPEAKER_ENCODER = not bool(CHECKPOINT_REPO)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from PIL import Image
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download, hf_hub_download
import gradio as gr

from external.arc2face import CLIPTextModelWrapper, project_face_embs
from core.models.encoder.speech_face_encoder import SpeechFaceXVectorEncoder

# ---------------------------------------------------------------------------
# Globals populated at startup
# ---------------------------------------------------------------------------
# All of these start as None and are assigned by load_models().
pipeline = None
speaker_encoder = None
facenet_model = None
facenet_classify_model = None
mtcnn_model = None
# Evaluated once at import time; generate() and load_models() re-evaluate it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------------------------------------------------------
# PEFT-compatible attention processors (inlined from core/factories/lora_factory.py)
# These fix "Linear.forward() takes 2 positional arguments but 3 were given"
# when using LoRA-wrapped UNet attention layers.
# ---------------------------------------------------------------------------
class PeftCompatibleAttnProcessor:
    """Attention processor that computes attention with an explicit ``torch.bmm``.

    Each projection (``attn.to_q``, ``attn.to_k``, ``attn.to_v``,
    ``attn.to_out[0]``) is called with exactly one positional argument, the
    tensor to project.
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run one attention layer.

        Args:
            attn: Attention module supplying the projections, norms and
                head-reshaping helpers used below.
            hidden_states: Query-side input, either 3-D or 4-D.
            encoder_hidden_states: Optional context; when ``None`` the layer
                attends over ``hidden_states`` itself.
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask``.
            temb: Forwarded to ``attn.spatial_norm`` when that norm is set.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The attended tensor, with the same layout as the input.
        """
        shortcut = hidden_states

        states = hidden_states
        if attn.spatial_norm is not None:
            states = attn.spatial_norm(states, temb)

        # 4-D inputs are flattened into a token sequence and restored at the end.
        is_spatial = states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = states.shape
            states = states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch and sequence sizes follow the context tensor when one is given.
        size_source = states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = size_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            states = attn.group_norm(states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(states)

        # Pick the key/value source: self-attention, normalised context, or raw context.
        if encoder_hidden_states is None:
            context = states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        weights = attn.get_attention_scores(query, key, attention_mask)
        states = attn.batch_to_head_dim(torch.bmm(weights, value))

        # Apply the two output modules in order.
        states = attn.to_out[1](attn.to_out[0](states))

        if is_spatial:
            states = states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            states = states + shortcut

        return states / attn.rescale_output_factor


class PeftCompatibleAttnProcessor2_0:
    """Variant built on ``torch.nn.functional.scaled_dot_product_attention``.

    As in the bmm variant, every projection receives only the tensor to
    project.
    """

    def __init__(self):
        """Raise ``ImportError`` when this PyTorch lacks scaled-dot-product attention."""
        sdpa_available = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("PeftCompatibleAttnProcessor2_0 requires PyTorch 2.0+.")

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run one attention layer; arguments and result match the bmm variant."""
        shortcut = hidden_states

        states = hidden_states
        if attn.spatial_norm is not None:
            states = attn.spatial_norm(states, temb)

        # 4-D inputs are flattened into a token sequence and restored at the end.
        is_spatial = states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = states.shape
            states = states.view(batch_size, channel, height * width).transpose(1, 2)

        size_source = states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = size_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape the prepared mask to (batch, heads, -1, last_dim) for SDPA.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            states = attn.group_norm(states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(states)

        if encoder_hidden_states is None:
            context = states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        attended = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge the heads back into the channel dimension.
        states = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        states = states.to(query.dtype)

        states = attn.to_out[1](attn.to_out[0](states))

        if is_spatial:
            states = states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            states = states + shortcut

        return states / attn.rescale_output_factor
def _set_attn_processor_for_lora(unet: nn.Module) -> None:
    """Install a PEFT-compatible attention processor on every attention layer of *unet*.

    Best effort: any failure is printed as a warning and otherwise ignored.
    """
    try:
        replacements = {}
        for layer_name in unet.attn_processors.keys():
            # Use the scaled-dot-product variant whenever this PyTorch offers it.
            if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
                processor = PeftCompatibleAttnProcessor2_0()
            else:
                processor = PeftCompatibleAttnProcessor()
            replacements[layer_name] = processor
        unet.set_attn_processor(replacements)
        print(" Set PEFT-compatible attention processors")
    except Exception as e:
        print(f" Warning: Could not set attention processors for LoRA: {e}")


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
MIN_AUDIO_SECONDS = 5.0


def load_and_process_audio(audio_file: str, dev: str, max_seconds: float = 6.0):
    """Load an audio file as a fixed-length 16 kHz mono waveform on *dev*.

    Args:
        audio_file: Path of the audio file to read.
        dev: Device the returned tensor is moved to.
        max_seconds: Target length; longer audio is truncated and shorter
            audio is zero-padded on the right.

    Returns:
        A tensor of ``int(max_seconds * 16000)`` samples per row.

    Raises:
        ValueError: If the audio is shorter than ``MIN_AUDIO_SECONDS``.
    """
    target_rate = 16000

    try:
        samples, source_rate = torchaudio.load(audio_file)
    except Exception:
        # torchaudio could not read the file; try soundfile instead.
        import soundfile as sf

        raw, source_rate = sf.read(audio_file, always_2d=True)
        # soundfile returns (frames, channels); transpose to (channels, frames).
        samples = torch.from_numpy(raw.T.astype(np.float32))

    if source_rate != target_rate:
        to_target = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_rate)
        samples = to_target(samples)

    # Mix multi-channel audio down to one channel.
    if samples.shape[0] > 1:
        samples = samples.mean(dim=0, keepdim=True)

    duration = samples.shape[1] / target_rate
    if duration < MIN_AUDIO_SECONDS:
        raise ValueError(f"Audio is too short ({duration:.1f}s). Please provide at least {MIN_AUDIO_SECONDS:.0f} seconds of speech.")

    # Force the exact target length: cut the excess or pad the shortfall.
    max_samples = int(max_seconds * target_rate)
    shortfall = max_samples - samples.shape[1]
    if shortfall < 0:
        samples = samples[:, :max_samples]
    elif shortfall > 0:
        samples = F.pad(samples, (0, shortfall))

    return samples.squeeze(0).unsqueeze(0).to(dev)
def is_lora_checkpoint(checkpoint_path: str, subfolder: str) -> bool:
    """Return True when ``<checkpoint_path>/<subfolder>/adapter_config.json`` exists."""
    adapter_config = os.path.join(checkpoint_path, subfolder, "adapter_config.json")
    return os.path.exists(adapter_config)


def resolve_checkpoint_path(checkpoint_path: str) -> str:
    """Pick the directory that actually holds the checkpoint.

    Resolution order for a directory:
      1. It contains both ``encoder`` and ``unet`` entries: return it.
      2. It contains ``checkpoint-<N>`` sub-directories: return the one with
         the largest ``N`` (non-numeric suffixes rank lowest).
      3. Otherwise: return it unchanged.

    A path that is not a directory is returned unchanged.

    Raises:
        FileNotFoundError: If the (user-expanded) path does not exist.
    """
    root = os.path.expanduser(checkpoint_path)
    if not os.path.exists(root):
        raise FileNotFoundError(f"Checkpoint path does not exist: {root}")
    if not os.path.isdir(root):
        return root

    entries = os.listdir(root)
    if {"encoder", "unet"}.issubset(entries):
        return root

    def step_of(name):
        # Numeric suffix after "checkpoint-"; -1 when it is not an integer.
        try:
            return int(name.split("checkpoint-")[-1])
        except ValueError:
            return -1

    candidates = [
        entry
        for entry in entries
        if entry.startswith("checkpoint-") and os.path.isdir(os.path.join(root, entry))
    ]
    if not candidates:
        return root

    candidates.sort(key=step_of)
    return os.path.join(root, candidates[-1])


# ---------------------------------------------------------------------------
# LoRA checkpoint loading
# ---------------------------------------------------------------------------
def load_encoder_with_lora(checkpoint_path: str):
    """Load the text-encoder wrapper, merging a LoRA adapter when one is present.

    With an adapter under ``<checkpoint_path>/lora/encoder``, the base encoder
    is loaded from ``ARC2FACE_REPO`` and the adapter is merged into it via
    ``merge_and_unload()``. Without one, the encoder is loaded directly from
    ``<checkpoint_path>/encoder``.
    """
    lora_subdir = os.path.join("lora", "encoder")
    if not is_lora_checkpoint(checkpoint_path, lora_subdir):
        return CLIPTextModelWrapper.from_pretrained(checkpoint_path, subfolder="encoder")

    from peft import PeftModel

    base_encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder')
    adapter_dir = os.path.join(checkpoint_path, lora_subdir)
    merged = PeftModel.from_pretrained(base_encoder, adapter_dir).merge_and_unload()
    # NOTE(review): forward is rebound to the base wrapper's bound method,
    # presumably to keep the wrapper's own forward signature — confirm.
    merged.forward = base_encoder.forward
    return merged
def load_unet_with_lora(checkpoint_path: str):
    """Load the UNet, merging a LoRA adapter when one is present.

    With an adapter under ``<checkpoint_path>/lora/unet``, the base UNet is
    loaded from ``ARC2FACE_REPO``, the adapter is merged via
    ``merge_and_unload()``, and PEFT-compatible attention processors are
    installed. Without one, the UNet is loaded directly from
    ``<checkpoint_path>/unet``.
    """
    lora_subdir = os.path.join("lora", "unet")
    if not is_lora_checkpoint(checkpoint_path, lora_subdir):
        return UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")

    from peft import PeftModel

    base_unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face')
    adapter_dir = os.path.join(checkpoint_path, lora_subdir)
    merged = PeftModel.from_pretrained(base_unet, adapter_dir).merge_and_unload()
    # NOTE(review): forward is rebound to the base model's bound method,
    # presumably to keep the original forward signature — confirm.
    merged.forward = base_unet.forward
    _set_attn_processor_for_lora(merged)
    return merged


# ---------------------------------------------------------------------------
# Raw WavLM encoder (fallback when no fine-tuned checkpoint is provided)
# ---------------------------------------------------------------------------
class RawWavLMEncoder:
    """Callable wrapper around ``WavLMForXVector`` that returns its x-vector embeddings."""

    def __init__(self, pretrained_path: str, dev: str):
        """Load the pretrained model, move it to *dev* and switch it to eval mode."""
        from transformers import WavLMForXVector

        model = WavLMForXVector.from_pretrained(pretrained_path)
        self.wavlm_xvector = model.to(dev)
        self.wavlm_xvector.eval()

    def __call__(self, waveform, normalize=True, apply_shared_projection=False):
        """Embed *waveform*; L2-normalise along dim 1 when *normalize* is true.

        ``apply_shared_projection`` is accepted but not used by this class.
        """
        outputs = self.wavlm_xvector(input_values=waveform, return_dict=True)
        embeddings = outputs.embeddings
        if not normalize:
            return embeddings
        return F.normalize(embeddings, p=2, dim=1)

    def eval(self):
        """Put the wrapped model in eval mode and return ``self``."""
        self.wavlm_xvector.eval()
        return self

    def to(self, dev):
        """Move the wrapped model to *dev* and return ``self``."""
        moved = self.wavlm_xvector.to(dev)
        self.wavlm_xvector = moved
        return self


# ---------------------------------------------------------------------------
# FaceNet best-sample selection
# ---------------------------------------------------------------------------
def _facenet_transform():
    """Build the preprocessing pipeline: resize to 160x160, to tensor, normalise."""
    from torchvision import transforms

    steps = [
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)


def _extract_facenet_emb(img: Image.Image, model) -> torch.Tensor:
    """Return the L2-normalised output of *model* for a single image."""
    preprocess = _facenet_transform()
    batch = preprocess(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        raw_embedding = model(batch)
    return F.normalize(raw_embedding.squeeze(0), p=2, dim=0)


def _extract_facenet_logits(img: Image.Image, model) -> torch.Tensor:
    """Return the output of *model* for a single image with the batch axis removed."""
    preprocess = _facenet_transform()
    batch = preprocess(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    return output.squeeze(0)
def select_best_images(pairs: list, n: int) -> list:
    """Return the top-n ``(image, seed)`` pairs ranked by pairwise similarity.

    Each image is scored by its average cosine similarity to every other
    image's FaceNet embedding.

    Args:
        pairs: List of ``(image, seed)`` tuples.
        n: Number of pairs to return; clamped to ``len(pairs)``.

    Returns:
        Up to ``n`` pairs, best first. With zero or one candidate, or when the
        FaceNet model is unavailable, the leading pairs are returned
        unranked. A non-positive ``n`` yields an empty list.
    """
    n = min(n, len(pairs))
    if n <= 0:
        return []
    # A lone candidate has no peers: averaging over ``len - 1`` would divide
    # by zero, so skip scoring altogether.
    if len(pairs) == 1 or facenet_model is None:
        return pairs[:n]

    images = [p[0] for p in pairs]
    embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img in images])
    sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
    # Subtract the self-similarity (1.0 on the diagonal) before averaging.
    avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(images) - 1)
    top_indices = avg_sims.argsort(descending=True)[:n].tolist()
    print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
    return [pairs[i] for i in top_indices]


def select_best_images_combined(pairs: list, n: int) -> list:
    """Return the top-n ``(image, seed)`` pairs ranked by detection x classification confidence.

    The score is the MTCNN detection probability multiplied by the highest
    softmax probability of the FaceNet classifier. Falls back to
    ``select_best_images`` when either model is unavailable.

    Args:
        pairs: List of ``(image, seed)`` tuples.
        n: Number of pairs to return; clamped to ``len(pairs)``.

    Returns:
        Up to ``n`` pairs, best first; an empty list for a non-positive ``n``.
    """
    n = min(n, len(pairs))
    if n <= 0:
        return []
    if mtcnn_model is None or facenet_classify_model is None:
        print("[select_best:combined] models unavailable, falling back to pairwise")
        return select_best_images(pairs, n)

    # The preprocessing pipeline is identical for every image: build it once.
    preprocess = _facenet_transform()

    scores = []
    for idx, (img, _) in enumerate(pairs):
        _, probs = mtcnn_model.detect(img)
        # No detection is scored as zero confidence.
        det_conf = float(probs[0]) if probs is not None and probs[0] is not None else 0.0

        tensor = preprocess(img.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = facenet_classify_model(tensor)
        classify_conf = float(F.softmax(logits, dim=1).max(dim=1).values[0])

        combined = det_conf * classify_conf
        scores.append(combined)
        print(f" [combined] idx={idx} det={det_conf:.3f} classify={classify_conf:.3f} combined={combined:.3f}")

    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
    print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
    return [pairs[i] for i in top_indices]


SELECTION_METHODS = ["Pairwise similarity", "Detection + Classify confidence"]
DEFAULT_SELECTION_METHOD = SELECTION_METHODS[0]

# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
#GENERATION_SEEDS = [42, 48, 56, 63, 74, 84, 86] #107, 119
GENERATION_SEEDS = [42, 48, 56, 63, 74, 84, 86, 107, 119, 124, 125, 127, 128, 129]
INTERNAL_SAMPLES = len(GENERATION_SEEDS)
SHOW_SEED_CAPTIONS = False
# If True, randomly pick DEFAULT_NUM_DISPLAY seeds to generate (faster).
# If False, generate all seeds and rank by quality.
RANDOM_SEED_SELECTION = True
@spaces.GPU(duration=120)
def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed, selection_method=DEFAULT_SELECTION_METHOD):
    """Generate face images conditioned on a speech clip.

    Args:
        audio_path: Path of the input audio file, or ``None``.
        num_display: Number of images to return.
        guidance_scale: Guidance scale passed to the diffusion pipeline.
        num_inference_steps: Number of denoising steps.
        base_seed: Accepted for interface compatibility; not read here. Seeds
            always come from ``GENERATION_SEEDS``.
        selection_method: Ranking strategy, one of ``SELECTION_METHODS``. Only
            consulted when ``RANDOM_SEED_SELECTION`` is false.

    Returns:
        ``(images, message)``. ``images`` is ``None`` and ``message`` is
        non-empty on failure. On success ``message`` is ``""`` and ``images``
        is a list of images, or of ``(image, caption)`` tuples when
        ``SHOW_SEED_CAPTIONS`` is set.
    """
    global device

    if audio_path is None:
        return None, "Please provide an audio file."

    # Re-evaluate the device on every call rather than trusting the import-time value.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[generate] device = {device}")

    if pipeline is None or speaker_encoder is None:
        print("[generate] Loading models lazily...")
        load_models()
        print("[generate] Models loaded.")
    if pipeline is None or speaker_encoder is None:
        return None, "Model loading failed. Check logs."

    try:
        waveform = load_and_process_audio(audio_path, device, max_seconds=5.0)
    except Exception as e:
        return None, f"Audio loading failed: {e}"

    dtype = torch.float16 if device == "cuda" else torch.float32
    with torch.no_grad():
        speech_z = speaker_encoder(
            waveform,
            normalize=True,
            apply_shared_projection=False,
        )
        id_emb = speech_z.to(dtype)
        id_emb_projected = project_face_embs(pipeline, id_emb)

    n = int(num_display)
    if RANDOM_SEED_SELECTION:
        # Clamp to [0, len(seeds)]: np.random.choice rejects a negative size.
        sample_size = max(0, min(n, len(GENERATION_SEEDS)))
        seeds_to_run = np.random.choice(GENERATION_SEEDS, size=sample_size, replace=False).tolist()
    else:
        seeds_to_run = GENERATION_SEEDS
    print(f"[generate] seeds_to_run={seeds_to_run}")

    pairs = []
    for seed in seeds_to_run:
        generator = torch.Generator(device=device).manual_seed(seed)
        img = pipeline(
            prompt_embeds=id_emb_projected,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            num_images_per_prompt=1,
            generator=generator,
        ).images[0]
        pairs.append((img, seed))

    if RANDOM_SEED_SELECTION:
        # Exactly the requested number was generated; nothing to rank.
        best = pairs
    elif selection_method == "Detection + Classify confidence":
        best = select_best_images_combined(pairs, n)
    else:
        best = select_best_images(pairs, n)

    return [(img, f"Seed: {seed}") if SHOW_SEED_CAPTIONS else img for img, seed in best], ""


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def _load_speaker_encoder(dev: str):
    """Build the speaker encoder: raw WavLM fallback or the fine-tuned checkpoint."""
    if SKIP_SPEAKER_ENCODER:
        encoder = RawWavLMEncoder("microsoft/wavlm-base-sv", dev)
        print(" Using raw WavLM x-vector encoder (no fine-tuned checkpoint)")
        return encoder

    enc = SpeechFaceXVectorEncoder(
        pretrained_path="microsoft/wavlm-base-sv",
        face_emb_dim=512,
        dropout=0.0,
        use_projection=True,
        freeze_feature_encoder=True,
    )
    encoder_pt = hf_hub_download(CHECKPOINT_REPO, ENCODER_FILENAME)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # downloaded file; CHECKPOINT_REPO must point to a trusted repository.
    ckpt = torch.load(encoder_pt, map_location=dev, weights_only=False)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    result = enc.load_state_dict(ckpt["model"], strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f" Warning: encoder state dict mismatch: missing={list(result.missing_keys)} unexpected={list(result.unexpected_keys)}")
    print(f" Loaded from {CHECKPOINT_REPO}/{ENCODER_FILENAME}")
    return enc.to(dev).eval()


def _build_pipeline(dev: str, dtype):
    """Assemble the Stable Diffusion pipeline around the Arc2Face encoder and UNet."""
    if SKIP_LORA:
        encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder', torch_dtype=dtype)
        unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face', torch_dtype=dtype)
        print(" Using base Arc2Face (no LoRA)")
    else:
        checkpoint_dir = snapshot_download(CHECKPOINT_REPO)
        checkpoint = resolve_checkpoint_path(checkpoint_dir)
        print(f" Checkpoint: {checkpoint}")
        encoder = load_encoder_with_lora(checkpoint).to(dtype=dtype)
        unet = load_unet_with_lora(checkpoint).to(dtype=dtype)

    pipe = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL,
        text_encoder=encoder,
        unet=unet,
        torch_dtype=dtype,
        safety_checker=None,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe.to(dev)


def load_models():
    """Load every model the demo needs and store them in the module globals.

    Assigns ``pipeline`` and ``speaker_encoder``, plus the optional
    ``facenet_model``, ``facenet_classify_model`` and ``mtcnn_model``. The
    last three are set to ``None`` when FaceNet/MTCNN cannot be loaded. Also
    refreshes the global ``device``.
    """
    global pipeline, speaker_encoder, facenet_model, facenet_classify_model, mtcnn_model, device

    # Resolve the device BEFORE choosing the dtype, so the dtype always
    # matches the device the models are moved to.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print(f"Using device: {device}")

    # Speaker encoder
    print("Loading speaker encoder...")
    speaker_encoder = _load_speaker_encoder(device)

    # Diffusion pipeline
    print("Loading diffusion pipeline...")
    pipeline = _build_pipeline(device, dtype)
    print(" Pipeline ready")

    # FaceNet + MTCNN for best-sample selection
    print("Loading FaceNet + MTCNN for best-sample selection...")
    try:
        from facenet_pytorch import InceptionResnetV1, MTCNN

        facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
        facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
        mtcnn_model = MTCNN(keep_all=False, device='cpu')
        print(" FaceNet + MTCNN ready")
    except Exception as e:
        # Optional dependency: ranking degrades gracefully without it.
        print(f" FaceNet/MTCNN unavailable ({e}); select-best will fall back to first image")
        facenet_model = None
        facenet_classify_model = None
        mtcnn_model = None
dropout=0.0, use_projection=True, freeze_feature_encoder=True, ) encoder_pt = hf_hub_download(CHECKPOINT_REPO, ENCODER_FILENAME) ckpt = torch.load(encoder_pt, map_location=device, weights_only=False) enc.load_state_dict(ckpt["model"], strict=False) speaker_encoder = enc.to(device).eval() print(f" Loaded from {CHECKPOINT_REPO}/{ENCODER_FILENAME}") # Diffusion pipeline print("Loading diffusion pipeline...") if SKIP_LORA: encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder', torch_dtype=dtype) unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face', torch_dtype=dtype) print(" Using base Arc2Face (no LoRA)") else: checkpoint_dir = snapshot_download(CHECKPOINT_REPO) checkpoint = resolve_checkpoint_path(checkpoint_dir) print(f" Checkpoint: {checkpoint}") encoder = load_encoder_with_lora(checkpoint).to(dtype=dtype) unet = load_unet_with_lora(checkpoint).to(dtype=dtype) pipeline = StableDiffusionPipeline.from_pretrained( BASE_MODEL, text_encoder=encoder, unet=unet, torch_dtype=dtype, safety_checker=None, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(device) print(" Pipeline ready") # FaceNet + MTCNN for best-sample selection print("Loading FaceNet + MTCNN for best-sample selection...") try: from facenet_pytorch import InceptionResnetV1, MTCNN facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval() facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval() mtcnn_model = MTCNN(keep_all=False, device='cpu') print(" FaceNet + MTCNN ready") except Exception as e: print(f" FaceNet/MTCNN unavailable ({e}); select-best will fall back to first image") facenet_model = None facenet_classify_model = None mtcnn_model = None # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def build_demo(): with 
gr.Blocks(title="AAS2F: Ambiguity-Aware Speech-to-Face Synthesis with Speaker-Conditioned Diffusion Models") as demo: gr.Markdown("# AAS2F: Ambiguity-Aware Speech-to-Face Synthesis with Speaker-Conditioned Diffusion Models") gr.Markdown( "**Steps to use the demo:**\n\n" "1. Upload or record a speech audio clip. **Please provide at least 5 seconds of speech.**\n" "2. Note that it works best with **English**, but should work with other languages as well.\n" "3. After you are done recording/uploading the audio, click the 'Generate' button to start the generation process.\n" "4. After a few seconds, the generated images will be displayed on the right." ) DEFAULT_NUM_DISPLAY = 3 DEFAULT_GUIDANCE_SCALE = 2.5 DEFAULT_NUM_STEPS = 50 DEFAULT_BASE_SEED = 42 with gr.Row(): with gr.Column(): with gr.Row(): audio_upload = gr.Audio( sources=["upload"], type="filepath", label="Upload Audio", ) audio_mic = gr.Audio( sources=["microphone"], type="filepath", label="Record Audio", ) generate_btn = gr.Button("Generate", variant="primary", interactive=False) with gr.Column(): gallery = gr.Gallery(label="Generated Images") status = gr.HTML(visible=False) def _update_btn(upload, mic): return gr.update(interactive=(upload is not None or mic is not None)) audio_upload.change(fn=_update_btn, inputs=[audio_upload, audio_mic], outputs=generate_btn) audio_mic.change(fn=_update_btn, inputs=[audio_upload, audio_mic], outputs=generate_btn) def _generate(upload, mic): audio = upload if upload is not None else mic imgs, msg = generate(audio, DEFAULT_NUM_DISPLAY, DEFAULT_GUIDANCE_SCALE, DEFAULT_NUM_STEPS, DEFAULT_BASE_SEED) if msg: error_html = f'