Spaces:
Sleeping
Sleeping
| """ | |
| AudioGAN Hugging Face Space - Text-to-Audio Generation (CPU only) | |
| GitHub: https://github.com/SeaSky1027/AudioGAN | |
| Hugging Face Model: https://huggingface.co/SeaSky1027/AudioGAN | |
| """ | |
| import os | |
| import sys | |
| import warnings | |
| import logging | |
| from pathlib import Path | |
| from datetime import datetime | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| # CPU only for Space | |
| device = "cpu" | |
| # Output directory for generated audio | |
| OUTPUT_DIR = Path("./output/gradio") | |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) | |
| # Checkpoint directory (Hugging Face model weights) | |
| CKPT_DIR = Path("./pretrained_models") | |
| CKPT_DIR.mkdir(parents=True, exist_ok=True) | |
| log = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO) | |
| # Global model cache | |
| MODEL_CACHE = None | |
def ensure_models_downloaded():
    """Download AudioGAN weights from Hugging Face if not present.

    Checks a sentinel set of required files under ``CKPT_DIR``; if any is
    missing, pulls the full repository snapshot from the Hub. Safe to call
    repeatedly — it is a no-op once all files exist.
    """
    required_files = (
        "generator.pt",
        "hifigan_16k_64bins.json",
        "hifigan_16k_64bins.ckpt",
    )
    # any() short-circuits on the first missing file, same as the
    # original flag-and-break loop but in one expression.
    if any(not (CKPT_DIR / name).exists() for name in required_files):
        log.info("Downloading AudioGAN model from Hugging Face...")
        snapshot_download(repo_id="SeaSky1027/AudioGAN", local_dir=str(CKPT_DIR))
    # Lazy %-style args avoid formatting when INFO is disabled.
    log.info("Checkpoint directory: %s", CKPT_DIR)
def load_model_for_inference(ckpt_path):
    """Load CLAP, Generator, HiFiGAN from checkpoint directory (same as infer.py).

    Returns a dict with keys ``clap_encoder``, ``generator``, ``vocoder``
    and ``clap_tester``, all in eval mode on CPU.
    """
    import CLAP
    from generator import Generator
    from HiFiGAN.inference import get_vocoder

    ckpt_dir = Path(ckpt_path)

    # Text encoder used to condition generation.
    clap_encoder = CLAP.CLAP_Module(amodel="HTSAT-base", tmodel="roberta").eval()
    clap_encoder.load_ckpt(str(ckpt_dir), "music_speech_audioset_epoch_15_esc_89.98.pt")

    # Mel-spectrogram generator restored from the trained checkpoint.
    audio_generator = Generator()
    state = torch.load(ckpt_dir / "generator.pt", weights_only=False, map_location="cpu")
    audio_generator.load_state_dict(state)

    # HiFiGAN vocoder: mel spectrogram -> 16 kHz waveform.
    vocoder = get_vocoder(sr=16000, ckpt_path=str(ckpt_dir)).eval()

    # Smaller CLAP used for scoring generated audio.
    clap_tester = CLAP.CLAP_Module(amodel="HTSAT-tiny", tmodel="roberta").eval()
    clap_tester.load_ckpt(str(ckpt_dir), "630k-audioset-best.pt")

    models = {
        "clap_encoder": clap_encoder,
        "generator": audio_generator,
        "vocoder": vocoder,
        "clap_tester": clap_tester,
    }
    # Ensure every component is in eval mode (the generator above was
    # never switched explicitly).
    for component in models.values():
        component.eval()
    return models
def get_text_embedding(text, model_dict):
    """Get CLAP text embedding.

    Returns (sentence_embedding, word_embedding, sequence_lengths) with
    the two embedding tensors detached from any autograd graph.
    """
    encoder = model_dict["clap_encoder"]
    with torch.no_grad():
        sent_emb, word_emb, seq_lens = encoder.get_text_embedding(text)
    return sent_emb.detach(), word_emb.detach(), seq_lens
def load_model_cache():
    """Load and cache the model (CPU).

    First call loads all components via ``load_model_for_inference`` and
    stores them in the module-level ``MODEL_CACHE``; subsequent calls
    return the cached dict immediately.
    """
    global MODEL_CACHE
    if MODEL_CACHE is None:
        log.info("Loading AudioGAN model (CPU)...")
        loaded = load_model_for_inference(CKPT_DIR)
        # Keep every component on CPU.
        MODEL_CACHE = {name: module.to(device) for name, module in loaded.items()}
        log.info("Model loaded.")
    return MODEL_CACHE
def generate_audio_gradio(prompt, seed):
    """Generate audio from text prompt with fixed seed (CPU).

    Args:
        prompt: Free-form text description of the desired sound.
        seed: Integer seed for the noise vector (reproducible output).

    Returns:
        (path_to_wav, prompt) — the saved 16 kHz WAV path and the prompt.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter the prompt.")
    model_dict = load_model_cache()

    # Normalize text (same as infer.py): drop one trailing period,
    # capitalize the first character.
    text = prompt.strip()
    if text.endswith("."):
        text = text[:-1]
    if text and text[0].islower():
        text = text[0].upper() + text[1:]
    text_list = [text]

    # Reproducible noise with seed.
    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))

    # Inference only: no_grad avoids building autograd graphs, a large
    # memory/speed win on a CPU-only Space.
    with torch.no_grad():
        sentence_embedding, word_embedding, sequence_lengths = get_text_embedding(text_list, model_dict)
        sentence_embedding = sentence_embedding.to(device)
        word_embedding = word_embedding.to(device)
        noise = torch.randn((1, 128), device=device, generator=generator)
        fake_mel = model_dict["generator"](noise, sentence_embedding, word_embedding, sequence_lengths)
        fake_sound = model_dict["vocoder"](fake_mel.squeeze())

    if fake_sound.dim() == 3:
        fake_sound = fake_sound.squeeze(1)
    fake_sound = fake_sound.detach().cpu()
    # torchaudio.save expects (channels, time); promote a bare 1-D
    # waveform so saving never fails on a squeezed vocoder output.
    if fake_sound.dim() == 1:
        fake_sound = fake_sound.unsqueeze(0)

    # Build a filesystem-safe filename stem from the prompt; fall back
    # when the prompt contains no safe characters at all.
    safe_prompt = "".join(c for c in prompt if c.isalnum() or c in (" ", "_")).rstrip().replace(" ", "_")[:50]
    if not safe_prompt:
        safe_prompt = "audio"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_prompt}_{timestamp}.wav"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), fake_sound, 16000)
    log.info("Saved: %s", save_path)
    return str(save_path), prompt
# UI: Prompt, Audio Sample, Seed, Examples only
input_text = gr.Textbox(lines=2, label="Prompt", placeholder="e.g. A bird is chirping in a quiet place.")
seed = gr.Number(value=42, label="Seed", minimum=0, maximum=2**32 - 1, step=1, precision=0)

# NOTE(review): the original markdown contained mojibake ("π", "π»",
# "π€", "π΅") from a UTF-8/Latin-1 decode mix-up; restored to the intended
# emoji characters.
description_text = """
###
AudioGAN is a novel GAN-based model tailored for compact and efficient text-to-audio generation.
- [📄 Paper (arXiv)](https://arxiv.org/abs/2512.22166)
- [💻 GitHub](https://github.com/SeaSky1027/AudioGAN)
- [🤗 Hugging Face Model](https://huggingface.co/SeaSky1027/AudioGAN)
- [🌐 Project Page](https://seasky1027.github.io/AudioGAN/)
This space uses only the CPU.
So, model inference may be slow.
Download the code from GitHub and the weights from huggingface.
Then, run inference on your own GPU for faster results.
"""

# Examples: Prompt and Seed only (no duration, steps, variant, etc.)
examples_list = [
    ["Chopping meat on a wooden table.", 10],
    ["A bird is chirping in a quiet place.", 27],
    ["Melodic human whistling harmonizing with natural birdsong", 1027],
]

gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="AudioGAN: A Compact and Efficient Framework for Real-Time High-Fidelity Text-to-Audio Generation",
    description=description_text,
    flagging_mode="never",
    examples=examples_list,
    cache_examples=False,
)
| if __name__ == "__main__": | |
| ensure_models_downloaded() | |
| load_model_cache() | |
| gr_interface.queue(max_size=10).launch() | |