| import os |
| import spaces |
| import sys |
# Pin OpenMP/MKL thread pools to one thread and force MKL's GNU threading layer.
# NOTE(review): this runs after `import spaces` but before `import torch`; presumably it
# must precede torch's import to take effect — confirm `spaces` does not import torch.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "MKL_THREADING_LAYER": "GNU",
    }
)

# Hugging Face repo with the fine-tuned weights (LoRA adapters are read from its
# snapshot; the speaker-encoder state dict is the file ENCODER_FILENAME inside it).
# An empty value means "use the base models only".
CHECKPOINT_REPO = os.environ.get("CHECKPOINT_REPO", "")
ENCODER_FILENAME = os.environ.get("ENCODER_FILENAME", "speaker_encoder.pt")
ARC2FACE_REPO = "FoivosPar/Arc2Face"
BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Both fine-tuned components come from CHECKPOINT_REPO, so both are skipped without it.
SKIP_LORA = SKIP_SPEAKER_ENCODER = not CHECKPOINT_REPO
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchaudio |
| from PIL import Image |
| from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler |
| from huggingface_hub import snapshot_download, hf_hub_download |
| import gradio as gr |
|
|
| from external.arc2face import CLIPTextModelWrapper, project_face_embs |
| from core.models.encoder.speech_face_encoder import SpeechFaceXVectorEncoder |
|
|
| |
| |
| |
# Model singletons, populated on first use by `load_models()`.
pipeline = speaker_encoder = None
facenet_model = facenet_classify_model = mtcnn_model = None
# Import-time guess; `generate()` and `load_models()` re-resolve it when they run.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| |
| |
| |
| |
| |
|
|
class PeftCompatibleAttnProcessor:
    """Plain (bmm-based) attention processor for diffusers attention modules.

    Installed by `_set_attn_processor_for_lora` when
    `torch.nn.functional.scaled_dot_product_attention` is not available. It calls the
    module's `to_q` / `to_k` / `to_v` / `to_out` layers directly and ignores any extra
    positional or keyword arguments it is given.
    NOTE(review): the "PEFT-compatible" name suggests this exists so LoRA-wrapped
    projection layers are invoked without extra arguments — confirm the motivation.
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply self- or cross-attention using the layers and helpers of ``attn``.

        Args:
            attn: the attention module supplying projections, norms and helper methods.
            hidden_states: either a 4-D (batch, channel, height, width) feature map,
                which is flattened to a token sequence and restored at the end, or a
                3-D token sequence.
            encoder_hidden_states: context for cross-attention; ``None`` means
                self-attention over ``hidden_states``.
            attention_mask: optional mask, passed through ``attn.prepare_attention_mask``.
            temb: forwarded to ``attn.spatial_norm`` when that norm is present.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten (B, C, H, W) feature maps into (B, H*W, C) token sequences.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is prepared against the key/value sequence length (the context's when
        # cross-attending).
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention: keys/values come from the (group-normed) hidden states.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Presumably folds the head dimension into the batch dimension so torch.bmm can
        # attend all heads at once (see batch_to_head_dim below) — confirm.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection, then to_out[1] (presumably dropout — confirm).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the original (B, C, H, W) layout for feature-map inputs.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
|
|
|
|
class PeftCompatibleAttnProcessor2_0:
    """SDPA-based attention processor for diffusers attention modules.

    Same flow as `PeftCompatibleAttnProcessor`, except the attention product is computed
    with `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+). Extra
    positional or keyword arguments are accepted and ignored.
    """

    def __init__(self):
        # Fail fast on PyTorch builds that lack fused scaled-dot-product attention.
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("PeftCompatibleAttnProcessor2_0 requires PyTorch 2.0+.")

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply self- or cross-attention using ``attn``'s layers and SDPA.

        Args:
            attn: the attention module supplying projections, norms and ``heads``.
            hidden_states: a 4-D (batch, channel, height, width) feature map, which is
                flattened to a token sequence and restored at the end, or a 3-D sequence.
            encoder_hidden_states: context for cross-attention; ``None`` means
                self-attention.
            attention_mask: optional mask; prepared and reshaped per head when given.
            temb: forwarded to ``attn.spatial_norm`` when that norm is present.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten (B, C, H, W) feature maps into (B, H*W, C) token sequences.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape to (batch, heads, -1, key_len) to line up with the per-head SDPA inputs.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention: keys/values come from the (group-normed) hidden states.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: (B, L, heads * head_dim) -> (B, heads, L, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: (B, heads, L, head_dim) -> (B, L, heads * head_dim).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection, then to_out[1] (presumably dropout — confirm).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the original (B, C, H, W) layout for feature-map inputs.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
|
|
|
|
| def _set_attn_processor_for_lora(unet: nn.Module) -> None: |
| try: |
| attn_procs = {} |
| for name in unet.attn_processors.keys(): |
| if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): |
| attn_procs[name] = PeftCompatibleAttnProcessor2_0() |
| else: |
| attn_procs[name] = PeftCompatibleAttnProcessor() |
| unet.set_attn_processor(attn_procs) |
| print(" Set PEFT-compatible attention processors") |
| except Exception as e: |
| print(f" Warning: Could not set attention processors for LoRA: {e}") |
|
|
|
|
| |
| |
| |
|
|
MIN_AUDIO_SECONDS = 5.0


def load_and_process_audio(audio_file: str, dev: str, max_seconds: float = 6.0):
    """Load ``audio_file`` as a fixed-length 16 kHz mono waveform on ``dev``.

    Decoding tries torchaudio first and falls back to soundfile. The signal is
    resampled to 16 kHz and down-mixed to mono by averaging channels. It is then
    truncated or right-padded with zeros to ``max_seconds`` seconds.

    Returns:
        A tensor of shape (1, int(max_seconds * 16000)) on ``dev``.

    Raises:
        ValueError: if the resampled audio is shorter than ``MIN_AUDIO_SECONDS``.
    """
    target_sr = 16000
    try:
        waveform, sample_rate = torchaudio.load(audio_file)
    except Exception:
        import soundfile as sf
        # always_2d gives (frames, channels); transpose to (channels, frames).
        frames, sample_rate = sf.read(audio_file, always_2d=True)
        waveform = torch.from_numpy(frames.T.astype(np.float32))

    if sample_rate != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(waveform)

    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    duration = waveform.shape[1] / target_sr
    if duration < MIN_AUDIO_SECONDS:
        raise ValueError(f"Audio is too short ({duration:.1f}s). Please provide at least {MIN_AUDIO_SECONDS:.0f} seconds of speech.")

    target_len = int(max_seconds * target_sr)
    # Truncate long clips (a no-op slice otherwise), then pad short ones with silence.
    waveform = waveform[:, :target_len]
    shortfall = target_len - waveform.shape[1]
    if shortfall > 0:
        waveform = F.pad(waveform, (0, shortfall))
    return waveform.squeeze(0).unsqueeze(0).to(dev)
|
|
|
|
def is_lora_checkpoint(checkpoint_path: str, subfolder: str) -> bool:
    """Return True when ``<checkpoint_path>/<subfolder>/adapter_config.json`` exists."""
    adapter_config = os.path.join(checkpoint_path, subfolder, "adapter_config.json")
    return os.path.exists(adapter_config)
|
|
|
|
def resolve_checkpoint_path(checkpoint_path: str) -> str:
    """Resolve a checkpoint location to the directory the model loaders should read.

    For a directory, in order:
    1. if it directly contains both ``encoder`` and ``unet`` entries, it is used as-is;
    2. otherwise, if it has ``checkpoint-<N>`` sub-directories, the one with the largest
       integer ``N`` is used (non-integer suffixes rank lowest, as -1);
    3. otherwise it is used as-is.
    An existing non-directory path is returned unchanged (after ``~`` expansion).

    Raises:
        FileNotFoundError: if the expanded path does not exist.
    """
    path = os.path.expanduser(checkpoint_path)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint path does not exist: {path}")
    if not os.path.isdir(path):
        return path

    entries = os.listdir(path)
    if {"encoder", "unet"} <= set(entries):
        return path

    def _step(name):
        try:
            return int(name.split("checkpoint-")[-1])
        except Exception:
            return -1

    candidates = [
        entry
        for entry in entries
        if entry.startswith("checkpoint-") and os.path.isdir(os.path.join(path, entry))
    ]
    if not candidates:
        return path
    # sorted() is stable: among equal steps, the last one in listing order wins.
    return os.path.join(path, sorted(candidates, key=_step)[-1])
|
|
|
|
| |
| |
| |
|
|
def load_encoder_with_lora(checkpoint_path: str):
    """Load the Arc2Face text encoder for ``checkpoint_path``.

    If ``<checkpoint_path>/lora/encoder`` holds a PEFT adapter (``adapter_config.json``),
    the adapter is applied to the base Arc2Face encoder and merged into its weights.
    Otherwise a full encoder is loaded from ``<checkpoint_path>/encoder``.
    """
    adapter_subdir = os.path.join("lora", "encoder")
    if not is_lora_checkpoint(checkpoint_path, adapter_subdir):
        return CLIPTextModelWrapper.from_pretrained(checkpoint_path, subfolder="encoder")

    from peft import PeftModel
    base_encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder')
    adapter_dir = os.path.join(checkpoint_path, adapter_subdir)
    merged = PeftModel.from_pretrained(base_encoder, adapter_dir).merge_and_unload()
    # NOTE(review): re-binds forward to the base wrapper's method; presumably to keep the
    # wrapper's custom forward after merging — confirm whether this is still required.
    merged.forward = base_encoder.forward
    return merged
|
|
|
|
def load_unet_with_lora(checkpoint_path: str):
    """Load the Arc2Face UNet for ``checkpoint_path``.

    If ``<checkpoint_path>/lora/unet`` holds a PEFT adapter (``adapter_config.json``), the
    adapter is applied to the base Arc2Face UNet and merged. PEFT-compatible attention
    processors are then installed on the result. Otherwise a full UNet is loaded from
    ``<checkpoint_path>/unet``.
    """
    adapter_subdir = os.path.join("lora", "unet")
    if not is_lora_checkpoint(checkpoint_path, adapter_subdir):
        return UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")

    from peft import PeftModel
    base_unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face')
    adapter_dir = os.path.join(checkpoint_path, adapter_subdir)
    merged = PeftModel.from_pretrained(base_unet, adapter_dir).merge_and_unload()
    # NOTE(review): re-binds forward to the base UNet's method after merging — confirm
    # whether this is still required.
    merged.forward = base_unet.forward
    _set_attn_processor_for_lora(merged)
    return merged
|
|
|
|
| |
| |
| |
|
|
class RawWavLMEncoder:
    """Speaker encoder backed by a stock ``WavLMForXVector`` checkpoint (no fine-tuning).

    Accepts the same call signature ``generate`` uses for the fine-tuned encoder
    (``waveform, normalize=..., apply_shared_projection=...``). ``eval`` and ``to``
    return ``self`` so calls can be chained.
    """

    def __init__(self, pretrained_path: str, dev: str):
        from transformers import WavLMForXVector
        model = WavLMForXVector.from_pretrained(pretrained_path).to(dev)
        model.eval()
        self.wavlm_xvector = model

    def __call__(self, waveform, normalize=True, apply_shared_projection=False):
        """Return x-vector embeddings, L2-normalized along dim 1 when ``normalize`` is set.

        ``apply_shared_projection`` is accepted only for interface compatibility and is
        ignored.
        """
        outputs = self.wavlm_xvector(input_values=waveform, return_dict=True)
        embeddings = outputs.embeddings
        return F.normalize(embeddings, p=2, dim=1) if normalize else embeddings

    def eval(self):
        """Switch the wrapped model to eval mode and return ``self``."""
        self.wavlm_xvector.eval()
        return self

    def to(self, dev):
        """Move the wrapped model to ``dev`` and return ``self``."""
        moved = self.wavlm_xvector.to(dev)
        self.wavlm_xvector = moved
        return self
|
|
|
|
| |
| |
| |
|
|
def _facenet_transform():
    """Build FaceNet preprocessing: resize to 160x160, convert to tensor, normalize."""
    from torchvision import transforms
    half = [0.5, 0.5, 0.5]
    steps = [
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        # mean = std = 0.5 per channel; maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize(half, half),
    ]
    return transforms.Compose(steps)
|
|
|
|
def _extract_facenet_emb(img: Image.Image, model) -> torch.Tensor:
    """Return the L2-normalized FaceNet embedding of ``img`` with the batch dim removed.

    Assumes ``model`` returns a (1, D) batch for a single image — TODO confirm.
    """
    batch = _facenet_transform()(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        embedding = model(batch).squeeze(0)
    return F.normalize(embedding, p=2, dim=0)
|
|
|
|
def _extract_facenet_logits(img: Image.Image, model) -> torch.Tensor:
    """Return ``model``'s raw (un-softmaxed) output for ``img`` with the batch dim removed."""
    with torch.no_grad():
        raw = model(_facenet_transform()(img.convert("RGB")).unsqueeze(0))
    return raw.squeeze(0)
|
|
|
|
def select_best_images(pairs: list, n: int) -> list:
    """Pick the ``n`` most mutually consistent images by pairwise FaceNet similarity.

    Each image is scored by its mean cosine similarity to every *other* image. Pairs are
    returned highest score first.

    Args:
        pairs: list of ``(image, seed)`` tuples.
        n: number of pairs to return; clamped to ``[0, len(pairs)]``.

    Returns:
        Up to ``n`` ``(image, seed)`` pairs. When FaceNet is unavailable, or there are
        fewer than two images to compare, the first ``n`` pairs are returned unchanged.
    """
    n = max(0, min(n, len(pairs)))
    # Ranking needs at least two images: torch.stack fails on an empty list, and the
    # mean below divides by (len - 1).
    if n == 0 or facenet_model is None or len(pairs) < 2:
        return pairs[:n]

    embeddings = torch.stack([_extract_facenet_emb(img, facenet_model) for img, _ in pairs])
    sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
    # Subtract the self-similarity term (1 for unit vectors) before averaging.
    avg_sims = (sim_matrix.sum(dim=1) - 1) / (len(pairs) - 1)
    top_indices = avg_sims.argsort(descending=True)[:n].tolist()
    print(f"[select_best:pairwise] top {n} indices={top_indices} avg_sims={avg_sims[top_indices].tolist()}")
    return [pairs[i] for i in top_indices]
|
|
|
|
def select_best_images_combined(pairs: list, n: int) -> list:
    """Rank images by face-detection confidence times identity-classifier confidence.

    For each image:
    - ``det`` is the first probability MTCNN reports, or 0.0 when it reports none.
    - ``classify`` is the largest softmax probability of the FaceNet classifier.

    Pairs are returned ordered by ``det * classify``, highest first.

    Args:
        pairs: list of ``(image, seed)`` tuples.
        n: number of pairs to return; clamped to ``[0, len(pairs)]``.

    Returns:
        Up to ``n`` ``(image, seed)`` pairs. Falls back to ``select_best_images`` when
        MTCNN or the classifier is unavailable.
    """
    n = max(0, min(n, len(pairs)))
    if mtcnn_model is None or facenet_classify_model is None:
        print("[select_best:combined] models unavailable, falling back to pairwise")
        return select_best_images(pairs, n)

    scores = []
    for idx, (img, _) in enumerate(pairs):
        _, probs = mtcnn_model.detect(img)
        # No detection can surface as probs None or probs[0] None; score those as 0.
        det_conf = float(probs[0]) if probs is not None and probs[0] is not None else 0.0

        # Shared helper returns the logits without the batch dim, so softmax over dim 0.
        logits = _extract_facenet_logits(img, facenet_classify_model)
        classify_conf = float(F.softmax(logits, dim=0).max())

        combined = det_conf * classify_conf
        scores.append(combined)
        print(f"  [combined] idx={idx} det={det_conf:.3f} classify={classify_conf:.3f} combined={combined:.3f}")

    # sorted() is stable, so ties keep their generation order.
    top_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:n]
    print(f"[select_best:combined] top {n} indices={top_indices} scores={[scores[i] for i in top_indices]}")
    return [pairs[i] for i in top_indices]
|
|
|
|
# UI labels for the best-sample ranking strategies `generate` understands.
SELECTION_METHODS = [
    "Pairwise similarity",              # -> select_best_images
    "Detection + Classify confidence",  # -> select_best_images_combined
]
DEFAULT_SELECTION_METHOD = SELECTION_METHODS[0]
|
|
|
|
| |
| |
| |
| |
| |
# Fixed pool of diffusion seeds that `generate` samples from (or runs in full).
GENERATION_SEEDS = [
    42, 48, 56, 63, 74, 84, 86,
    107, 119, 124, 125, 127, 128, 129,
]
INTERNAL_SAMPLES = len(GENERATION_SEEDS)
# When True, gallery entries are (image, "Seed: N") tuples instead of bare images.
SHOW_SEED_CAPTIONS = False
# When True, `generate` draws num_display random seeds and skips best-sample ranking.
RANDOM_SEED_SELECTION = True
|
|
@spaces.GPU(duration=120)
def generate(audio_path, num_display, guidance_scale, num_inference_steps, base_seed, selection_method=DEFAULT_SELECTION_METHOD):
    """Synthesize face images from a speech clip.

    Args:
        audio_path: path to the input audio; ``None`` yields an error message.
        num_display: number of images to return.
        guidance_scale: guidance scale passed to the diffusion pipeline.
        num_inference_steps: number of diffusion steps.
        base_seed: currently unused; seeds are taken from ``GENERATION_SEEDS``.
        selection_method: one of ``SELECTION_METHODS``; only consulted when
            ``RANDOM_SEED_SELECTION`` is False.

    Returns:
        ``(images, message)``. On success, ``images`` is a list of PIL images, or of
        ``(image, caption)`` tuples when ``SHOW_SEED_CAPTIONS`` is set, and ``message`` is
        ``""``. On failure, ``images`` is ``None`` and ``message`` describes the problem.
    """
    global device

    if audio_path is None:
        return None, "Please provide an audio file."

    # Re-resolve the device here, inside the @spaces.GPU-decorated call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[generate] device = {device}")

    if pipeline is None or speaker_encoder is None:
        print("[generate] Loading models lazily...")
        try:
            load_models()
        except Exception as e:
            # Report loading failures through the status message instead of letting the
            # exception escape the request.
            print(f"[generate] Model loading raised: {e!r}")
            return None, "Model loading failed. Check logs."
        print("[generate] Models loaded.")

    if pipeline is None or speaker_encoder is None:
        return None, "Model loading failed. Check logs."

    try:
        waveform = load_and_process_audio(audio_path, device, max_seconds=5.0)
    except Exception as e:
        return None, f"Audio loading failed: {e}"

    dtype = torch.float16 if device == "cuda" else torch.float32
    n = int(num_display)

    if RANDOM_SEED_SELECTION:
        # n distinct seeds; their images are returned without ranking.
        seeds_to_run = np.random.choice(GENERATION_SEEDS, size=min(n, len(GENERATION_SEEDS)), replace=False).tolist()
    else:
        seeds_to_run = GENERATION_SEEDS
    print(f"[generate] seeds_to_run={seeds_to_run}")

    try:
        with torch.no_grad():
            speech_z = speaker_encoder(
                waveform,
                normalize=True,
                apply_shared_projection=False,
            )
            # Cast the speech embedding to the pipeline dtype before projecting it.
            id_emb_projected = project_face_embs(pipeline, speech_z.to(dtype))

        pairs = []
        for seed in seeds_to_run:
            generator = torch.Generator(device=device).manual_seed(seed)
            img = pipeline(
                prompt_embeds=id_emb_projected,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=float(guidance_scale),
                num_images_per_prompt=1,
                generator=generator,
            ).images[0]
            pairs.append((img, seed))
    except Exception as e:
        print(f"[generate] Generation raised: {e!r}")
        return None, f"Image generation failed: {e}"

    if RANDOM_SEED_SELECTION:
        best = pairs
    elif selection_method == SELECTION_METHODS[1]:
        best = select_best_images_combined(pairs, n)
    else:
        best = select_best_images(pairs, n)

    return [(img, f"Seed: {seed}") if SHOW_SEED_CAPTIONS else img for img, seed in best], ""
|
|
| |
| |
| |
|
|
def _load_speaker_encoder(dev: str):
    """Build the speech encoder: raw WavLM without a checkpoint repo, else the fine-tuned one."""
    print("Loading speaker encoder...")
    if SKIP_SPEAKER_ENCODER:
        raw_encoder = RawWavLMEncoder("microsoft/wavlm-base-sv", dev)
        print(" Using raw WavLM x-vector encoder (no fine-tuned checkpoint)")
        return raw_encoder

    enc = SpeechFaceXVectorEncoder(
        pretrained_path="microsoft/wavlm-base-sv",
        face_emb_dim=512,
        dropout=0.0,
        use_projection=True,
        freeze_feature_encoder=True,
    )
    encoder_pt = hf_hub_download(CHECKPOINT_REPO, ENCODER_FILENAME)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the downloaded
    # file. Only point CHECKPOINT_REPO at a repository you trust.
    ckpt = torch.load(encoder_pt, map_location=dev, weights_only=False)
    # strict=False tolerates key mismatches; report them so a wrong checkpoint is visible.
    load_result = enc.load_state_dict(ckpt["model"], strict=False)
    missing = getattr(load_result, "missing_keys", None) or []
    unexpected = getattr(load_result, "unexpected_keys", None) or []
    if missing or unexpected:
        print(f" Warning: speaker encoder state dict mismatch (missing={len(missing)}, unexpected={len(unexpected)})")
    tuned_encoder = enc.to(dev).eval()
    print(f" Loaded from {CHECKPOINT_REPO}/{ENCODER_FILENAME}")
    return tuned_encoder


def _load_diffusion_pipeline(dev: str, dtype):
    """Assemble the Arc2Face Stable Diffusion pipeline (optionally LoRA-merged) on ``dev``."""
    print("Loading diffusion pipeline...")
    if SKIP_LORA:
        encoder = CLIPTextModelWrapper.from_pretrained(ARC2FACE_REPO, subfolder='encoder', torch_dtype=dtype)
        unet = UNet2DConditionModel.from_pretrained(ARC2FACE_REPO, subfolder='arc2face', torch_dtype=dtype)
        print(" Using base Arc2Face (no LoRA)")
    else:
        checkpoint_dir = snapshot_download(CHECKPOINT_REPO)
        checkpoint = resolve_checkpoint_path(checkpoint_dir)
        print(f" Checkpoint: {checkpoint}")
        encoder = load_encoder_with_lora(checkpoint).to(dtype=dtype)
        unet = load_unet_with_lora(checkpoint).to(dtype=dtype)

    pipe = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL,
        text_encoder=encoder,
        unet=unet,
        torch_dtype=dtype,
        safety_checker=None,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(dev)
    print(" Pipeline ready")
    return pipe


def _load_selection_models():
    """Load FaceNet (embedder + classifier) and MTCNN on CPU; return Nones if unavailable."""
    print("Loading FaceNet + MTCNN for best-sample selection...")
    try:
        from facenet_pytorch import InceptionResnetV1, MTCNN
        embedder = InceptionResnetV1(pretrained='vggface2', classify=False).eval()
        classifier = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
        detector = MTCNN(keep_all=False, device='cpu')
    except Exception as e:
        # Best-effort: the selection functions fall back when these are None.
        print(f" FaceNet/MTCNN unavailable ({e}); select-best will fall back to first image")
        return None, None, None
    print(" FaceNet + MTCNN ready")
    return embedder, classifier, detector


def load_models():
    """Load every model `generate` needs into the module-level globals.

    The device is resolved *before* the dtype is derived from it: float16 on CUDA,
    float32 otherwise. This keeps the pipeline's dtype consistent with the device it is
    moved to and with the dtype `generate` casts embeddings to.
    Exceptions from the speaker encoder or the pipeline propagate to the caller. A
    failure to load the FaceNet/MTCNN selection models is tolerated (the globals are
    left as None).
    """
    global pipeline, speaker_encoder, facenet_model, facenet_classify_model, mtcnn_model, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print(f"Using device: {device}")

    speaker_encoder = _load_speaker_encoder(device)
    pipeline = _load_diffusion_pipeline(device, dtype)
    facenet_model, facenet_classify_model, mtcnn_model = _load_selection_models()
|
|
|
|
| |
| |
| |
|
|
def build_demo():
    """Build the Gradio Blocks UI: audio inputs, a Generate button, a gallery and a status box.

    Errors returned by `generate` are shown in the status box as HTML-escaped text.
    """
    import html

    with gr.Blocks(title="AAS2F: Ambiguity-Aware Speech-to-Face Synthesis with Speaker-Conditioned Diffusion Models") as demo:
        gr.Markdown("# AAS2F: Ambiguity-Aware Speech-to-Face Synthesis with Speaker-Conditioned Diffusion Models")
        gr.Markdown(
            "**Steps to use the demo:**\n\n"
            "1. Upload or record a speech audio clip. **Please provide at least 5 seconds of speech.**\n"
            "2. Note that it works best with **English**, but should work with other languages as well.\n"
            "3. After you are done recording/uploading the audio, click the 'Generate' button to start the generation process.\n"
            "4. After a few seconds, the generated images will be displayed on the right."
        )

        DEFAULT_NUM_DISPLAY = 3
        DEFAULT_GUIDANCE_SCALE = 2.5
        DEFAULT_NUM_STEPS = 50
        DEFAULT_BASE_SEED = 42

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    audio_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Audio",
                    )
                    audio_mic = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Record Audio",
                    )
                generate_btn = gr.Button("Generate", variant="primary", interactive=False)

            with gr.Column():
                gallery = gr.Gallery(label="Generated Images")
                status = gr.HTML(visible=False)

        def _update_btn(upload, mic):
            # Generate is only clickable while at least one audio source is set.
            return gr.update(interactive=(upload is not None or mic is not None))

        audio_upload.change(fn=_update_btn, inputs=[audio_upload, audio_mic], outputs=generate_btn)
        audio_mic.change(fn=_update_btn, inputs=[audio_upload, audio_mic], outputs=generate_btn)

        def _generate(upload, mic):
            # An uploaded file takes precedence over a microphone recording.
            audio = upload if upload is not None else mic
            imgs, msg = generate(audio, DEFAULT_NUM_DISPLAY, DEFAULT_GUIDANCE_SCALE, DEFAULT_NUM_STEPS, DEFAULT_BASE_SEED)
            if msg:
                # msg can embed exception text (possibly including the uploaded file's
                # path/name). Escape it so it renders as text, not as HTML markup.
                error_html = f'<div style="background:#fee2e2;border:1px solid #f87171;border-radius:8px;padding:12px 16px;color:#b91c1c;font-size:0.95em;">⚠️ {html.escape(msg)}</div>'
                return imgs, gr.update(value=error_html, visible=True)
            return imgs, gr.update(value="", visible=False)

        # Disable the button and clear the status, run generation, then restore the button.
        generate_btn.click(
            fn=lambda u, m: (gr.update(value="Generating...", interactive=False, variant="secondary"), gr.update(value="", visible=False)),
            inputs=[audio_upload, audio_mic],
            outputs=[generate_btn, status],
        ).then(
            fn=_generate,
            inputs=[audio_upload, audio_mic],
            outputs=[gallery, status],
        ).then(
            fn=lambda u, m: gr.update(value="Generate", interactive=(u is not None or m is not None), variant="primary"),
            inputs=[audio_upload, audio_mic],
            outputs=generate_btn,
        )

    return demo
|
|
|
|
| |
| |
| |
|
|
# Build the Gradio app at import time, enable request queuing, then start the server.
demo = build_demo()
demo.queue()
demo.launch()
|
|