# AAS2F / app.py
# TaliDror
# improved UI
# a5ca1f0
import os
import spaces
import sys
# Restrict OpenMP and MKL to a single thread each and select the GNU threading
# layer for MKL. These assignments sit above the numpy/torch imports further
# down — presumably so the native libraries pick them up when they are first
# loaded (NOTE(review): confirm the motivation for these values).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["MKL_THREADING_LAYER"] = "GNU"
# ---------------------------------------------------------------------------
# Configuration — set CHECKPOINT_REPO as a HuggingFace Space secret to load
# fine-tuned models. If left empty, the demo uses base Arc2Face with a raw
# WavLM x-vector encoder (useful for testing that the Space works).
# ---------------------------------------------------------------------------
# Values that can be overridden through the process environment.
CHECKPOINT_REPO = os.getenv("CHECKPOINT_REPO", "")
ENCODER_FILENAME = os.getenv("ENCODER_FILENAME", "speaker_encoder.pt")

# Hub identifiers of the base weights used by the diffusion pipeline.
ARC2FACE_REPO = "FoivosPar/Arc2Face"
BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# An empty CHECKPOINT_REPO switches off both the LoRA weights and the
# fine-tuned speaker encoder.
SKIP_LORA = SKIP_SPEAKER_ENCODER = not CHECKPOINT_REPO
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from PIL import Image
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download, hf_hub_download
import gradio as gr
from external.arc2face import CLIPTextModelWrapper, project_face_embs
from core.models.encoder.speech_face_encoder import SpeechFaceXVectorEncoder
# ---------------------------------------------------------------------------
# Globals populated at startup
# ---------------------------------------------------------------------------
# Model handles start out empty; load_models() assigns them through `global`.
pipeline = speaker_encoder = None
facenet_model = facenet_classify_model = mtcnn_model = None

# Initial compute device; generate() and load_models() re-evaluate it later.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ---------------------------------------------------------------------------
# PEFT-compatible attention processors (inlined from core/factories/lora_factory.py)
# These fix "Linear.forward() takes 2 positional arguments but 3 were given"
# when using LoRA-wrapped UNet attention layers.
# ---------------------------------------------------------------------------
class PeftCompatibleAttnProcessor:
    """Attention processor built on explicit ``torch.bmm`` attention.

    Every projection layer of ``attn`` (``to_q``, ``to_k``, ``to_v``,
    ``to_out[0]``) is called with exactly one positional argument, which is
    what avoids the "Linear.forward() takes 2 positional arguments but 3 were
    given" failure on LoRA-wrapped layers.
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for ``attn`` and return the transformed states.

        Args:
            attn: Attention module providing the projections, norms and
                head-reshaping helpers used below.
            hidden_states: Query-side input, either 3-D or 4-D. A 4-D input
                is flattened over its last two axes and restored on return.
            encoder_hidden_states: Optional key/value source; when ``None``
                the (normalised) ``hidden_states`` are used instead.
            attention_mask: Optional mask handed to
                ``attn.prepare_attention_mask``.
            temb: Extra input forwarded to ``attn.spatial_norm`` when set.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor with the same layout as the incoming ``hidden_states``.
        """
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel).
        spatial = hidden_states.ndim == 4
        if spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized against the key/value sequence.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        weights = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(weights, value))

        # Output projection followed by the second module of ``to_out``.
        attended = attn.to_out[1](attn.to_out[0](attended))

        if spatial:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_input

        return attended / attn.rescale_output_factor
class PeftCompatibleAttnProcessor2_0:
    """Attention processor built on ``scaled_dot_product_attention``.

    Like the bmm-based variant, each projection layer of ``attn`` is called
    with a single positional argument so LoRA-wrapped layers do not receive
    an unexpected extra argument.
    """

    def __init__(self):
        """Fail early when the fused attention kernel is unavailable."""
        sdpa_available = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("PeftCompatibleAttnProcessor2_0 requires PyTorch 2.0+.")

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for ``attn`` and return the transformed states.

        Args:
            attn: Attention module providing projections, norms and ``heads``.
            hidden_states: Query-side input, either 3-D or 4-D. A 4-D input
                is flattened over its last two axes and restored on return.
            encoder_hidden_states: Optional key/value source; when ``None``
                the (normalised) ``hidden_states`` are used instead.
            attention_mask: Optional mask; prepared and reshaped to one slice
                per head only when it is not ``None``.
            temb: Extra input forwarded to ``attn.spatial_norm`` when set.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor with the same layout as the incoming ``hidden_states``.
        """
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel).
        spatial = hidden_states.ndim == 4
        if spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = prepared.view(batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        attended = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back and keep the dtype of the query.
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        attended = attn.to_out[1](attn.to_out[0](attended))

        if spatial:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_input

        return attended / attn.rescale_output_factor
def _set_attn_processor_for_lora(unet: nn.Module) -> None:
try:
attn_procs = {}
for name in unet.attn_processors.keys():
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
attn_procs[name] = PeftCompatibleAttnProcessor2_0()
else:
attn_procs[name] = PeftCompatibleAttnProcessor()
unet.set_attn_processor(attn_procs)
print(" Set PEFT-compatible attention processors")
except Exception as e:
print(f" Warning: Could not set attention processors for LoRA: {e}")
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
# Clips shorter than this (after resampling) are rejected.
MIN_AUDIO_SECONDS = 5.0


def load_and_process_audio(audio_file: str, dev: str, max_seconds: float = 6.0):
    """Load an audio file as a fixed-length mono 16 kHz waveform on ``dev``.

    Loading is tried with torchaudio first; if that raises, the file is read
    with soundfile instead. The signal is resampled to 16 kHz when needed,
    mixed down to one channel, then truncated or zero-padded on the right to
    exactly ``max_seconds`` of samples.

    Args:
        audio_file: Path of the audio file to read.
        dev: Device the returned tensor is moved to.
        max_seconds: Target length of the returned clip in seconds.

    Returns:
        The waveform tensor with a leading axis of size one, moved to ``dev``.

    Raises:
        ValueError: If the clip is shorter than ``MIN_AUDIO_SECONDS``.
    """
    target_rate = 16000

    try:
        waveform, sample_rate = torchaudio.load(audio_file)
    except Exception:
        import soundfile as sf
        samples, sample_rate = sf.read(audio_file, always_2d=True)
        # soundfile yields (frames, channels); transpose to (channels, frames).
        waveform = torch.from_numpy(samples.T.astype(np.float32))

    if sample_rate != target_rate:
        to_target_rate = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = to_target_rate(waveform)

    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    num_samples = waveform.shape[1]
    duration = num_samples / target_rate
    if duration < MIN_AUDIO_SECONDS:
        raise ValueError(f"Audio is too short ({duration:.1f}s). Please provide at least {MIN_AUDIO_SECONDS:.0f} seconds of speech.")

    max_samples = int(max_seconds * target_rate)
    if num_samples > max_samples:
        waveform = waveform[:, :max_samples]
    elif num_samples < max_samples:
        waveform = F.pad(waveform, (0, max_samples - num_samples))

    return waveform.squeeze(0).unsqueeze(0).to(dev)
def is_lora_checkpoint(checkpoint_path: str, subfolder: str) -> bool:
    """Return True when ``subfolder`` of the checkpoint holds an ``adapter_config.json``."""
    adapter_config = os.path.join(checkpoint_path, subfolder, "adapter_config.json")
    return os.path.exists(adapter_config)
def resolve_checkpoint_path(checkpoint_path: str) -> str:
    """Pick the directory that actually holds the model weights.

    The path (after ``~`` expansion) is returned unchanged when it is a file,
    when it already contains both ``encoder`` and ``unet`` entries, or when it
    has no ``checkpoint-*`` sub-directories. Otherwise the ``checkpoint-N``
    sub-directory with the largest numeric suffix is returned; names whose
    suffix is not an integer rank lowest.

    Raises:
        FileNotFoundError: If the expanded path does not exist.
    """
    root = os.path.expanduser(checkpoint_path)
    if not os.path.exists(root):
        raise FileNotFoundError(f"Checkpoint path does not exist: {root}")
    if not os.path.isdir(root):
        return root

    entries = os.listdir(root)
    if {"encoder", "unet"} <= set(entries):
        return root

    def step_of(dirname):
        # Non-numeric suffixes sort before every real step number.
        try:
            return int(dirname.split("checkpoint-")[-1])
        except Exception:
            return -1

    candidates = [
        entry
        for entry in entries
        if entry.startswith("checkpoint-") and os.path.isdir(os.path.join(root, entry))
    ]
    if not candidates:
        return root

    candidates.sort(key=step_of)
    return os.path.join(root, candidates[-1])
# ---------------------------------------------------------------------------
# LoRA checkpoint loading
# ---------------------------------------------------------------------------
def load_encoder_with_lora(checkpoint_path: str):
    """Load the text-encoder wrapper, merging LoRA weights when present.

    If ``<checkpoint_path>/lora/encoder`` contains an adapter config, the base
    Arc2Face encoder is loaded, the adapter is applied and merged, and the
    merged model is returned. Otherwise the encoder is loaded directly from
    the ``encoder`` subfolder of the checkpoint.
    """
    lora_subfolder = os.path.join("lora", "encoder")
    if not is_lora_checkpoint(checkpoint_path, lora_subfolder):
        return CLIPTextModelWrapper.from_pretrained(checkpoint_path, subfolder="encoder")

    from peft import PeftModel

    base_encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder')
    adapter_dir = os.path.join(checkpoint_path, lora_subfolder)
    merged = PeftModel.from_pretrained(base_encoder, adapter_dir).merge_and_unload()
    # NOTE(review): forward is rebound to the base model's bound method; the
    # reason is not visible in this file — confirm before removing.
    merged.forward = base_encoder.forward
    return merged
def load_unet_with_lora(checkpoint_path: str):
    """Load the UNet, merging LoRA weights when present.

    If ``<checkpoint_path>/lora/unet`` contains an adapter config, the base
    Arc2Face UNet is loaded, the adapter is applied and merged, and
    PEFT-compatible attention processors are installed on the result.
    Otherwise the UNet is loaded directly from the ``unet`` subfolder of the
    checkpoint.
    """
    lora_subfolder = os.path.join("lora", "unet")
    if not is_lora_checkpoint(checkpoint_path, lora_subfolder):
        return UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")

    from peft import PeftModel

    base_unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face')
    adapter_dir = os.path.join(checkpoint_path, lora_subfolder)
    merged = PeftModel.from_pretrained(base_unet, adapter_dir).merge_and_unload()
    # NOTE(review): forward is rebound to the base model's bound method; the
    # reason is not visible in this file — confirm before removing.
    merged.forward = base_unet.forward
    _set_attn_processor_for_lora(merged)
    return merged
# ---------------------------------------------------------------------------
# Raw WavLM encoder (fallback when no fine-tuned checkpoint is provided)
# ---------------------------------------------------------------------------
class RawWavLMEncoder:
    """Thin wrapper exposing a pretrained WavLM x-vector model as a speaker encoder.

    It mimics the call signature used for the fine-tuned encoder so either
    one can be invoked the same way.
    """

    def __init__(self, pretrained_path: str, dev: str):
        """Load ``WavLMForXVector`` from ``pretrained_path`` onto ``dev`` in eval mode."""
        from transformers import WavLMForXVector

        model = WavLMForXVector.from_pretrained(pretrained_path).to(dev)
        model.eval()
        self.wavlm_xvector = model

    def __call__(self, waveform, normalize=True, apply_shared_projection=False):
        """Return the model's embeddings for ``waveform``.

        Args:
            waveform: Passed to the model as ``input_values``.
            normalize: When true, L2-normalise the embeddings along dim 1.
            apply_shared_projection: Accepted for signature compatibility;
                not used here.
        """
        outputs = self.wavlm_xvector(input_values=waveform, return_dict=True)
        embeddings = outputs.embeddings
        if not normalize:
            return embeddings
        return F.normalize(embeddings, p=2, dim=1)

    def eval(self):
        """Put the wrapped model in eval mode and return ``self``."""
        self.wavlm_xvector.eval()
        return self

    def to(self, dev):
        """Move the wrapped model to ``dev`` and return ``self``."""
        moved = self.wavlm_xvector.to(dev)
        self.wavlm_xvector = moved
        return self
# ---------------------------------------------------------------------------
# FaceNet best-sample selection
# ---------------------------------------------------------------------------
def _facenet_transform():
    """Build the image preprocessing used before the FaceNet models.

    The pipeline resizes to 160x160, converts to a tensor, and normalises
    each of the three channels with mean 0.5 and std 0.5. torchvision is
    imported lazily, on first use.
    """
    from torchvision import transforms

    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    steps = [
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
def _extract_facenet_emb(img: Image.Image, model) -> torch.Tensor:
    """Return the L2-normalised output of ``model`` for a single image.

    The image is converted to RGB, preprocessed with ``_facenet_transform``
    and passed through ``model`` as a batch of one without gradients.
    """
    rgb = img.convert("RGB")
    batch = _facenet_transform()(rgb).unsqueeze(0)
    with torch.no_grad():
        raw = model(batch)
    flat = raw.squeeze(0)
    return F.normalize(flat, p=2, dim=0)
def _extract_facenet_logits(img: Image.Image, model) -> torch.Tensor:
    """Return the un-normalised output of ``model`` for a single image.

    The image is converted to RGB, preprocessed with ``_facenet_transform``
    and passed through ``model`` as a batch of one without gradients; the
    batch axis is removed from the result.
    """
    rgb = img.convert("RGB")
    batch = _facenet_transform()(rgb).unsqueeze(0)
    with torch.no_grad():
        raw = model(batch)
    return raw.squeeze(0)
def select_best_images(pairs: list, n: int) -> list:
    """Return the ``n`` candidates most similar, on average, to all the others.

    Each image is embedded with the module-level ``facenet_model``; a
    candidate's score is its mean cosine similarity to every other candidate.

    Args:
        pairs: List of ``(image, seed)`` tuples.
        n: Number of pairs to return; capped at ``len(pairs)``.

    Returns:
        Up to ``n`` ``(image, seed)`` pairs ordered by descending score. The
        first ``n`` pairs are returned in their original order when there are
        fewer than two candidates or when no FaceNet model is loaded.
    """
    n = min(n, len(pairs))

    # With fewer than two candidates there is nothing to compare against:
    # the averaging below would divide by zero (one image) or torch.stack
    # would fail on an empty list (no images).
    if len(pairs) < 2:
        return pairs[:n]

    if facenet_model is None:
        return pairs[:n]

    images = [pair[0] for pair in pairs]
    embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img in images])

    # Pairwise cosine similarity; subtract 1 to drop each image's similarity
    # with itself before averaging over the remaining images.
    sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
    avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(images) - 1)

    top_indices = avg_sims.argsort(descending=True)[:n].tolist()
    print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
    return [pairs[i] for i in top_indices]
def select_best_images_combined(pairs: list, n: int) -> list:
    """Return the ``n`` candidates with the highest detection x classification score.

    For each image the score is the product of the first MTCNN detection
    probability (0.0 when none is reported) and the largest softmax
    probability of the FaceNet classifier. If either model is missing, the
    pairwise-similarity selection is used instead.

    Args:
        pairs: List of ``(image, seed)`` tuples.
        n: Number of pairs to return; capped at ``len(pairs)``.

    Returns:
        Up to ``n`` ``(image, seed)`` pairs ordered by descending score.
    """
    global mtcnn_model, facenet_classify_model

    n = min(n, len(pairs))
    detector, classifier = mtcnn_model, facenet_classify_model
    if detector is None or classifier is None:
        print("[select_best:combined] models unavailable, falling back to pairwise")
        return select_best_images(pairs, n)

    scores = []
    for idx, (img, _seed) in enumerate(pairs):
        _, probs = detector.detect(img)
        det_conf = 0.0
        if probs is not None and probs[0] is not None:
            det_conf = float(probs[0])

        batch = _facenet_transform()(img.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = classifier(batch)
        class_probs = F.softmax(logits, dim=1)
        classify_conf = float(class_probs.max(dim=1).values[0])

        combined = det_conf * classify_conf
        scores.append(combined)
        print(f" [combined] idx={idx} det={det_conf:.3f} classify={classify_conf:.3f} combined={combined:.3f}")

    # Stable descending sort: equal scores keep their original order.
    ranked = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)
    top_indices = [index for index, _score in ranked[:n]]
    print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
    return [pairs[i] for i in top_indices]
# Names of the available ranking strategies; the first entry is the default.
SELECTION_METHODS = [
    "Pairwise similarity",
    "Detection + Classify confidence",
]
DEFAULT_SELECTION_METHOD = SELECTION_METHODS[0]
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
# Fixed pool of diffusion seeds. An earlier, shorter pool was
# [42, 48, 56, 63, 74, 84, 86]; 107 and 119 were noted as later additions.
GENERATION_SEEDS = [
    42, 48, 56, 63, 74, 84, 86,
    107, 119, 124, 125, 127, 128, 129,
]
INTERNAL_SAMPLES = len(GENERATION_SEEDS)

# When True, gallery items are (image, "Seed: N") tuples instead of bare images.
SHOW_SEED_CAPTIONS = False

# True: draw only as many seeds from the pool as images were requested and
# show every result (faster). False: run every seed in the pool and rank the
# results with the chosen selection method.
RANDOM_SEED_SELECTION = True
@spaces.GPU(duration=120)
def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed, selection_method=DEFAULT_SELECTION_METHOD):
    """Generate face images conditioned on a speech clip.

    Args:
        audio_path: Path of the input audio, or ``None``.
        num_display: Number of images to return.
        guidance_scale: Guidance scale passed to the diffusion pipeline.
        num_inference_steps: Number of denoising steps.
        base_seed: Accepted for interface compatibility; not used here, the
            seeds come from ``GENERATION_SEEDS``.
        selection_method: Ranking strategy; consulted only when
            ``RANDOM_SEED_SELECTION`` is false.

    Returns:
        ``(images, message)``. On failure ``images`` is ``None`` and
        ``message`` explains why; on success ``message`` is empty.
    """
    global pipeline, speaker_encoder, facenet_model, device

    if audio_path is None:
        return None, "Please provide an audio file."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[generate] device = {device}")

    # Models are loaded on first use.
    models_missing = pipeline is None or speaker_encoder is None
    if models_missing:
        print("[generate] Loading models lazily...")
        load_models()
        print("[generate] Models loaded.")
    if pipeline is None or speaker_encoder is None:
        return None, "Model loading failed. Check logs."

    try:
        waveform = load_and_process_audio(audio_path, device, max_seconds=5.0)
    except Exception as e:
        return None, f"Audio loading failed: {e}"

    dtype = torch.float16 if device == "cuda" else torch.float32

    # Speech embedding -> conditioning embedding for the pipeline.
    with torch.no_grad():
        speech_z = speaker_encoder(
            waveform,
            normalize=True,
            apply_shared_projection=False,
        )
        id_emb = speech_z.to(dtype)
        id_emb_projected = project_face_embs(pipeline, id_emb)

    n = int(num_display)
    if RANDOM_SEED_SELECTION:
        sample_size = min(n, len(GENERATION_SEEDS))
        seeds_to_run = np.random.choice(GENERATION_SEEDS, size=sample_size, replace=False).tolist()
    else:
        seeds_to_run = GENERATION_SEEDS
    print(f"[generate] seeds_to_run={seeds_to_run}")

    pairs = []
    for seed in seeds_to_run:
        generator = torch.Generator(device=device).manual_seed(seed)
        result = pipeline(
            prompt_embeds=id_emb_projected,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            num_images_per_prompt=1,
            generator=generator,
        )
        pairs.append((result.images[0], seed))

    if RANDOM_SEED_SELECTION:
        best = pairs
    elif selection_method == "Detection + Classify confidence":
        best = select_best_images_combined(pairs, n)
    else:
        best = select_best_images(pairs, n)

    if SHOW_SEED_CAPTIONS:
        outputs = [(img, f"Seed: {seed}") for img, seed in best]
    else:
        outputs = [img for img, seed in best]
    return outputs, ""
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_models():
    """Load every model the demo needs into the module-level globals.

    Populates ``speaker_encoder``, ``pipeline`` and, on a best-effort basis,
    ``facenet_model``, ``facenet_classify_model`` and ``mtcnn_model``. The
    global ``device`` is refreshed first and the working dtype is derived
    from it (float16 on CUDA, float32 otherwise).

    Failures while loading the speaker encoder or the diffusion pipeline
    propagate to the caller; failures while loading FaceNet/MTCNN are printed
    and leave those three globals set to ``None``.
    """
    global pipeline, speaker_encoder, facenet_model, facenet_classify_model, mtcnn_model, device

    # Refresh the device BEFORE choosing the dtype, so the precision matches
    # the device that is actually used rather than a stale global value.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print(f"Using device: {device}")

    # Speaker encoder
    print("Loading speaker encoder...")
    if SKIP_SPEAKER_ENCODER:
        speaker_encoder = RawWavLMEncoder("microsoft/wavlm-base-sv", device)
        print(" Using raw WavLM x-vector encoder (no fine-tuned checkpoint)")
    else:
        enc = SpeechFaceXVectorEncoder(
            pretrained_path="microsoft/wavlm-base-sv",
            face_emb_dim=512,
            dropout=0.0,
            use_projection=True,
            freeze_feature_encoder=True,
        )
        encoder_pt = hf_hub_download(CHECKPOINT_REPO, ENCODER_FILENAME)
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects from a
        # downloaded file. Only point CHECKPOINT_REPO at a trusted repository.
        ckpt = torch.load(encoder_pt, map_location=device, weights_only=False)
        # strict=False: keys that do not match the model are ignored silently.
        enc.load_state_dict(ckpt["model"], strict=False)
        speaker_encoder = enc.to(device).eval()
        print(f" Loaded from {CHECKPOINT_REPO}/{ENCODER_FILENAME}")

    # Diffusion pipeline
    print("Loading diffusion pipeline...")
    if SKIP_LORA:
        encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder', torch_dtype=dtype)
        unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face', torch_dtype=dtype)
        print(" Using base Arc2Face (no LoRA)")
    else:
        checkpoint_dir = snapshot_download(CHECKPOINT_REPO)
        checkpoint = resolve_checkpoint_path(checkpoint_dir)
        print(f" Checkpoint: {checkpoint}")
        encoder = load_encoder_with_lora(checkpoint).to(dtype=dtype)
        unet = load_unet_with_lora(checkpoint).to(dtype=dtype)

    pipeline = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL,
        text_encoder=encoder,
        unet=unet,
        torch_dtype=dtype,
        safety_checker=None,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(device)
    print(" Pipeline ready")

    # FaceNet + MTCNN for best-sample selection (optional; kept on the CPU)
    print("Loading FaceNet + MTCNN for best-sample selection...")
    try:
        from facenet_pytorch import InceptionResnetV1, MTCNN
        facenet_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
        facenet_classify_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
        mtcnn_model = MTCNN(keep_all=False, device='cpu')
        print(" FaceNet + MTCNN ready")
    except Exception as e:
        # Best-effort: selection falls back to taking the first images.
        print(f" FaceNet/MTCNN unavailable ({e}); select-best will fall back to first image")
        facenet_model = None
        facenet_classify_model = None
        mtcnn_model = None
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_demo():
    """Build and return the Gradio Blocks UI for the demo.

    The layout has an upload widget and a microphone widget next to each
    other with the Generate button, plus a gallery and an initially hidden
    status box for error messages. The button is enabled only while at least
    one audio input holds a value. Clicking it runs three chained steps:
    mark the button as busy, call ``generate``, restore the button.
    """
    # Local import: used to escape error text before embedding it in HTML.
    import html

    with gr.Blocks(title="AAS2F: Ambiguity-Aware Speech-to-Face Synthesis with Speaker-Conditioned Diffusion Models") as demo:
        gr.Markdown("# AAS2F: Ambiguity-Aware Speech-to-Face Synthesis with Speaker-Conditioned Diffusion Models")
        gr.Markdown(
            "**Steps to use the demo:**\n\n"
            "1. Upload or record a speech audio clip. **Please provide at least 5 seconds of speech.**\n"
            "2. Note that it works best with **English**, but should work with other languages as well.\n"
            "3. After you are done recording/uploading the audio, click the 'Generate' button to start the generation process.\n"
            "4. After a few seconds, the generated images will be displayed on the right."
        )

        # Fixed generation settings; the UI exposes no controls for them.
        DEFAULT_NUM_DISPLAY = 3
        DEFAULT_GUIDANCE_SCALE = 2.5
        DEFAULT_NUM_STEPS = 50
        DEFAULT_BASE_SEED = 42

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    audio_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Audio",
                    )
                    audio_mic = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Record Audio",
                    )
                generate_btn = gr.Button("Generate", variant="primary", interactive=False)
            with gr.Column():
                gallery = gr.Gallery(label="Generated Images")
                status = gr.HTML(visible=False)

        def _update_btn(upload, mic):
            # Enable the button as soon as either audio input has a value.
            return gr.update(interactive=(upload is not None or mic is not None))

        audio_upload.change(fn=_update_btn, inputs=[audio_upload, audio_mic], outputs=generate_btn)
        audio_mic.change(fn=_update_btn, inputs=[audio_upload, audio_mic], outputs=generate_btn)

        def _generate(upload, mic):
            # An uploaded file takes precedence over a recording.
            audio = upload if upload is not None else mic
            imgs, msg = generate(audio, DEFAULT_NUM_DISPLAY, DEFAULT_GUIDANCE_SCALE, DEFAULT_NUM_STEPS, DEFAULT_BASE_SEED)
            if msg:
                # The message can contain exception text, so escape it before
                # it is rendered as HTML.
                safe_msg = html.escape(msg)
                error_html = f'<div style="background:#fee2e2;border:1px solid #f87171;border-radius:8px;padding:12px 16px;color:#b91c1c;font-size:0.95em;">⚠️ {safe_msg}</div>'
                return imgs, gr.update(value=error_html, visible=True)
            return imgs, gr.update(value="", visible=False)

        generate_btn.click(
            fn=lambda u, m: (gr.update(value="Generating...", interactive=False, variant="secondary"), gr.update(value="", visible=False)),
            inputs=[audio_upload, audio_mic],
            outputs=[generate_btn, status],
        ).then(
            fn=_generate,
            inputs=[audio_upload, audio_mic],
            outputs=[gallery, status],
        ).then(
            fn=lambda u, m: gr.update(value="Generate", interactive=(u is not None or m is not None), variant="primary"),
            inputs=[audio_upload, audio_mic],
            outputs=generate_btn,
        )
    return demo
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
# Build the UI, enable Gradio's request queue, then start the server.
# These statements run at import time; there is no __main__ guard.
demo = build_demo()
demo.queue()
demo.launch()