"""
AudioGAN Hugging Face Space - Text-to-Audio Generation (CPU only)

GitHub: https://github.com/SeaSky1027/AudioGAN
Hugging Face Model: https://huggingface.co/SeaSky1027/AudioGAN
"""

import os
import sys
import warnings
import logging
from pathlib import Path
from datetime import datetime

warnings.filterwarnings("ignore")

import torch
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download

# CPU only for Space
device = "cpu"

# Output directory for generated audio
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoint directory (Hugging Face model weights)
CKPT_DIR = Path("./pretrained_models")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Sample rate (Hz) requested from the HiFiGAN vocoder and used when writing WAVs.
# Kept in one place so the vocoder and the saved files cannot disagree.
SAMPLE_RATE = 16000

# Global model cache (populated lazily by load_model_cache()).
MODEL_CACHE = None


def ensure_models_downloaded():
    """Download AudioGAN weights from Hugging Face if any required file is missing.

    The whole ``SeaSky1027/AudioGAN`` repository is fetched into ``CKPT_DIR``
    with ``snapshot_download`` when at least one of the checked files is absent.
    When every checked file is present this function only logs.
    """
    # NOTE(review): the two CLAP checkpoints loaded in load_model_for_inference
    # ("music_speech_audioset_epoch_15_esc_89.98.pt", "630k-audioset-best.pt")
    # are not checked here — confirm whether they ship in the HF repo or are
    # fetched by CLAP itself before adding them to this list.
    required_files = (
        "generator.pt",
        "hifigan_16k_64bins.json",
        "hifigan_16k_64bins.ckpt",
    )
    if any(not (CKPT_DIR / name).exists() for name in required_files):
        log.info("Downloading AudioGAN model from Hugging Face...")
        snapshot_download(repo_id="SeaSky1027/AudioGAN", local_dir=str(CKPT_DIR))
    log.info("Checkpoint directory: %s", CKPT_DIR)


def load_model_for_inference(ckpt_path):
    """Load CLAP, Generator, HiFiGAN from checkpoint directory (same as infer.py).

    Args:
        ckpt_path: Directory containing the generator, vocoder and CLAP
            checkpoint files.

    Returns:
        dict with keys ``"clap_encoder"``, ``"generator"``, ``"vocoder"`` and
        ``"clap_tester"``; every module is put in eval mode.
    """
    # Project-local modules; imported lazily so the UI module can be imported
    # without them.
    import CLAP
    from generator import Generator
    from HiFiGAN.inference import get_vocoder

    ckpt_path = Path(ckpt_path)
    model_dict = {}

    model_dict["clap_encoder"] = CLAP.CLAP_Module(amodel="HTSAT-base", tmodel="roberta").eval()
    model_dict["clap_encoder"].load_ckpt(str(ckpt_path), "music_speech_audioset_epoch_15_esc_89.98.pt")

    model_dict["generator"] = Generator()
    # weights_only=False unpickles arbitrary objects; acceptable only because the
    # file comes from the project's own HF repo — never point this at untrusted data.
    ckpt_state = torch.load(ckpt_path / "generator.pt", weights_only=False, map_location="cpu")
    model_dict["generator"].load_state_dict(ckpt_state)

    model_dict["vocoder"] = get_vocoder(sr=SAMPLE_RATE, ckpt_path=str(ckpt_path)).eval()

    model_dict["clap_tester"] = CLAP.CLAP_Module(amodel="HTSAT-tiny", tmodel="roberta").eval()
    model_dict["clap_tester"].load_ckpt(str(ckpt_path), "630k-audioset-best.pt")

    for module in model_dict.values():
        module.eval()
    return model_dict


def get_text_embedding(text, model_dict):
    """Get CLAP text embedding.

    Args:
        text: List of prompt strings.
        model_dict: Models returned by ``load_model_for_inference``.

    Returns:
        ``(sentence_embedding, word_embedding, sequence_lengths)`` as produced by
        the CLAP encoder, with both embeddings detached from the graph.
    """
    with torch.no_grad():
        sentence_embedding, word_embedding, sequence_lengths = model_dict["clap_encoder"].get_text_embedding(text)
    return sentence_embedding.detach(), word_embedding.detach(), sequence_lengths


def load_model_cache():
    """Load and cache the model (CPU).

    Makes sure the weights are present first, so this also works when the
    module is imported rather than run as ``__main__``. Subsequent calls return
    the cached dict.
    """
    global MODEL_CACHE
    if MODEL_CACHE is not None:
        return MODEL_CACHE

    # Idempotent: only downloads when required files are missing.
    ensure_models_downloaded()

    log.info("Loading AudioGAN model (CPU)...")
    models = load_model_for_inference(CKPT_DIR)
    # Keep on CPU
    MODEL_CACHE = {name: module.to(device) for name, module in models.items()}
    log.info("Model loaded.")
    return MODEL_CACHE


def _normalize_prompt(prompt):
    """Normalize a prompt the same way as infer.py.

    Strips surrounding whitespace, drops one trailing period and upper-cases
    the first character if it is lowercase. May return an empty string.
    """
    text = prompt.strip()
    if text.endswith("."):
        text = text[:-1]
    if text and text[0].islower():
        text = text[0].upper() + text[1:]
    return text


def _safe_filename_stem(prompt, max_len=50):
    """Turn a prompt into a filesystem-safe stem (alnum, '_' only; spaces -> '_').

    Falls back to ``"audio"`` when nothing usable remains.
    """
    kept = "".join(c for c in prompt.strip() if c.isalnum() or c in (" ", "_"))
    return kept.rstrip().replace(" ", "_")[:max_len] or "audio"


@torch.inference_mode()
def generate_audio_gradio(prompt, seed):
    """Generate audio from text prompt with fixed seed (CPU).

    Args:
        prompt: Free-text description of the sound.
        seed: Integer seed for the generator's input noise; ``None`` when the
            Gradio field was cleared.

    Returns:
        ``(path_to_wav, prompt)`` for the Audio and Textbox outputs.

    Raises:
        gr.Error: if the prompt is empty (also after normalization) or the
            seed is missing.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter the prompt.")
    if seed is None:
        # A cleared gr.Number yields None; int(None) would raise a raw TypeError.
        raise gr.Error("Enter the seed.")
    seed = int(seed)

    # Normalize text (same as infer.py)
    text = _normalize_prompt(prompt)
    if not text:
        # e.g. a prompt of just "." becomes empty after normalization.
        raise gr.Error("Enter the prompt.")
    text_list = [text]

    model_dict = load_model_cache()

    # Reproducible noise with seed
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    sentence_embedding, word_embedding, sequence_lengths = get_text_embedding(text_list, model_dict)
    sentence_embedding = sentence_embedding.to(device)
    word_embedding = word_embedding.to(device)

    noise = torch.randn((1, 128), device=device, generator=generator)
    fake_mel = model_dict["generator"](noise, sentence_embedding, word_embedding, sequence_lengths)
    fake_sound = model_dict["vocoder"](fake_mel.squeeze())

    # torchaudio.save needs a 2-D (channels, time) tensor.
    if fake_sound.dim() == 3:
        # presumably (batch, 1, time) -> (batch, time); TODO confirm vocoder layout
        fake_sound = fake_sound.squeeze(1)
    elif fake_sound.dim() == 1:
        fake_sound = fake_sound.unsqueeze(0)
    fake_sound = fake_sound.detach().cpu()

    # Save to file; seed + microsecond timestamp avoid clobbering concurrent outputs.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{_safe_filename_stem(prompt)}_seed{seed}_{timestamp}.wav"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), fake_sound, SAMPLE_RATE)
    log.info("Saved: %s", save_path)

    return str(save_path), prompt


# UI: Prompt, Audio Sample, Seed, Examples only
input_text = gr.Textbox(lines=2, label="Prompt", placeholder="e.g. A bird is chirping in a quiet place.")
seed = gr.Number(value=42, label="Seed", minimum=0, maximum=2**32 - 1, step=1, precision=0)

description_text = """
### AudioGAN is a novel GAN-based model tailored for compact and efficient text-to-audio generation.

- [📖 Paper (arXiv)](https://arxiv.org/abs/2512.22166)
- [💻 GitHub](https://github.com/SeaSky1027/AudioGAN)
- [🤗 Hugging Face Model](https://huggingface.co/SeaSky1027/AudioGAN)
- [🌐 Project Page](https://seasky1027.github.io/AudioGAN/)

This space uses only the CPU. So, model inference may be slow.
Download the code from GitHub and the weights from huggingface.
Then, run inference on your own GPU for faster results.
"""

# Examples: Prompt and Seed only (no duration, steps, variant, etc.)
examples_list = [
    ["Chopping meat on a wooden table.", 10],
    ["A bird is chirping in a quiet place.", 27],
    ["Melodic human whistling harmonizing with natural birdsong", 1027],
]

gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="AudioGAN: A Compact and Efficient Framework for Real-Time High-Fidelity Text-to-Audio Generation",
    description=description_text,
    flagging_mode="never",
    examples=examples_list,
    cache_examples=False,
)


if __name__ == "__main__":
    ensure_models_downloaded()
    load_model_cache()
    gr_interface.queue(max_size=10).launch()