"""
AudioGAN Hugging Face Space - Text-to-Audio Generation (CPU only)
GitHub: https://github.com/SeaSky1027/AudioGAN
Hugging Face Model: https://huggingface.co/SeaSky1027/AudioGAN
"""
import os
import sys
import warnings
import logging
from pathlib import Path
from datetime import datetime
warnings.filterwarnings("ignore")
import torch
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download
# CPU only for Space: Hugging Face free Spaces have no GPU, so all
# inference below runs on the CPU device.
device = "cpu"

# Output directory for generated audio clips served back through Gradio.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoint directory (Hugging Face model weights); populated by
# ensure_models_downloaded() and read by load_model_for_inference().
CKPT_DIR = Path("./pretrained_models")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Module-level logger, configured at import time for Space logs.
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Global model cache: None until load_model_cache() populates it once.
MODEL_CACHE = None
def ensure_models_downloaded():
    """Download AudioGAN weights from Hugging Face if not present.

    Checks CKPT_DIR for the generator weights and the HiFi-GAN vocoder
    config/checkpoint; if any is missing, pulls a full snapshot of the
    ``SeaSky1027/AudioGAN`` repo into CKPT_DIR.  The CLAP checkpoints
    needed by load_model_for_inference() are not listed here explicitly;
    the snapshot download brings the whole repo, so they arrive with it.
    """
    required_files = (
        "generator.pt",
        "hifigan_16k_64bins.json",
        "hifigan_16k_64bins.ckpt",
    )
    # any() short-circuits on the first missing file (replaces the
    # original manual need_download flag + break loop).
    if any(not (CKPT_DIR / name).exists() for name in required_files):
        log.info("Downloading AudioGAN model from Hugging Face...")
        snapshot_download(repo_id="SeaSky1027/AudioGAN", local_dir=str(CKPT_DIR))
    log.info(f"Checkpoint directory: {CKPT_DIR}")
def load_model_for_inference(ckpt_path):
    """Load CLAP, Generator, HiFiGAN from checkpoint directory (same as infer.py).

    Args:
        ckpt_path: Directory containing all checkpoint files (str or Path).

    Returns:
        dict with keys "clap_encoder", "generator", "vocoder",
        "clap_tester", each an eval()-mode torch module.
    """
    # Imported lazily so the module can be imported (and weights
    # downloaded) before these project packages are needed.
    import CLAP
    from generator import Generator
    from HiFiGAN.inference import get_vocoder

    ckpt_path = Path(ckpt_path)
    model_dict = {}
    # Text encoder used at generation time (HTSAT-base audio tower + RoBERTa text tower).
    model_dict["clap_encoder"] = CLAP.CLAP_Module(amodel="HTSAT-base", tmodel="roberta").eval()
    model_dict["clap_encoder"].load_ckpt(str(ckpt_path), "music_speech_audioset_epoch_15_esc_89.98.pt")
    # GAN generator mapping (noise, text embeddings) -> mel spectrogram.
    # NOTE(review): weights_only=False unpickles arbitrary objects; fine
    # for this trusted repo checkpoint, never for untrusted files.
    model_dict["generator"] = Generator()
    ckpt_state = torch.load(ckpt_path / "generator.pt", weights_only=False, map_location="cpu")
    model_dict["generator"].load_state_dict(ckpt_state)
    # HiFi-GAN vocoder: mel spectrogram -> 16 kHz waveform.
    model_dict["vocoder"] = get_vocoder(sr=16000, ckpt_path=str(ckpt_path)).eval()
    # Second, smaller CLAP (HTSAT-tiny) — presumably used for scoring
    # generated audio elsewhere in the project; not called in this file.
    model_dict["clap_tester"] = CLAP.CLAP_Module(amodel="HTSAT-tiny", tmodel="roberta").eval()
    model_dict["clap_tester"].load_ckpt(str(ckpt_path), "630k-audioset-best.pt")
    # Redundant belt-and-braces: force eval mode on every sub-model.
    for name in model_dict:
        model_dict[name].eval()
    return model_dict
def get_text_embedding(text, model_dict):
    """Encode *text* with the CLAP text encoder in *model_dict*.

    Returns a (sentence_embedding, word_embedding, sequence_lengths)
    triple; both embedding tensors are detached from the autograd graph.
    """
    encoder = model_dict["clap_encoder"]
    with torch.no_grad():
        sent_emb, word_emb, seq_lens = encoder.get_text_embedding(text)
    return sent_emb.detach(), word_emb.detach(), seq_lens
def load_model_cache():
    """Return the process-wide model dict, loading it on first call (CPU)."""
    global MODEL_CACHE
    if MODEL_CACHE is None:
        log.info("Loading AudioGAN model (CPU)...")
        MODEL_CACHE = load_model_for_inference(CKPT_DIR)
        # Pin every sub-model to the Space's device (CPU).
        for key in list(MODEL_CACHE):
            MODEL_CACHE[key] = MODEL_CACHE[key].to(device)
        log.info("Model loaded.")
    return MODEL_CACHE
@torch.inference_mode()
def generate_audio_gradio(prompt, seed):
    """Generate audio from a text prompt with a fixed seed (CPU).

    Args:
        prompt: Free-form text description of the desired sound.
        seed: Integer seed for the latent noise (reproducible output).

    Returns:
        (wav_file_path, prompt) — the saved 16 kHz WAV path and the
        original prompt, matching the two Gradio output components.

    Raises:
        gr.Error: if the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter the prompt.")
    model_dict = load_model_cache()

    # Normalize text (same as infer.py): drop one trailing period and
    # capitalize the first character.
    text = prompt.strip()
    if text.endswith("."):
        text = text[:-1]
    if text and text[0].islower():
        text = text[0].upper() + text[1:]
    text_list = [text]

    # Reproducible latent noise from the user-supplied seed.
    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))

    sentence_embedding, word_embedding, sequence_lengths = get_text_embedding(text_list, model_dict)
    sentence_embedding = sentence_embedding.to(device)
    word_embedding = word_embedding.to(device)

    noise = torch.randn((1, 128), device=device, generator=generator)
    fake_mel = model_dict["generator"](noise, sentence_embedding, word_embedding, sequence_lengths)
    fake_sound = model_dict["vocoder"](fake_mel.squeeze())

    # torchaudio.save requires a 2-D (channels, time) tensor.  The vocoder
    # output rank is not visible from here (TODO confirm), so normalize
    # both the batched 3-D case and — new fix — a raw 1-D waveform, which
    # previously would have crashed the save call.
    if fake_sound.dim() == 3:
        fake_sound = fake_sound.squeeze(1)
    elif fake_sound.dim() == 1:
        fake_sound = fake_sound.unsqueeze(0)
    fake_sound = fake_sound.detach().cpu()

    # Build a filesystem-safe filename; fall back to "audio" when the
    # prompt contains no filename-safe characters at all (fix: previously
    # produced a name starting with a bare underscore).
    safe_prompt = "".join(c for c in prompt if c.isalnum() or c in (" ", "_")).rstrip().replace(" ", "_")[:50]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_prompt or 'audio'}_{timestamp}.wav"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), fake_sound, 16000)
    log.info(f"Saved: {save_path}")
    return str(save_path), prompt
# UI: Prompt, Audio Sample, Seed, Examples only
input_text = gr.Textbox(lines=2, label="Prompt", placeholder="e.g. A bird is chirping in a quiet place.")
# Seed is constrained to a non-negative 32-bit range; precision=0 makes
# Gradio deliver it as an integer.
seed = gr.Number(value=42, label="Seed", minimum=0, maximum=2**32 - 1, step=1, precision=0)
# Markdown shown under the Space title.
# NOTE(review): the link emojis were mojibake ("π", "π»", "π€", "π΅") —
# UTF-8 emoji bytes decoded as ISO-8859-7.  Reconstructed from the byte
# sequences as 📄 💻 🤗 🌐 (last one inferred; confirm against the repo).
description_text = """
###
AudioGAN is a novel GAN-based model tailored for compact and efficient text-to-audio generation.
- [📄 Paper (arXiv)](https://arxiv.org/abs/2512.22166)
- [💻 GitHub](https://github.com/SeaSky1027/AudioGAN)
- [🤗 Hugging Face Model](https://huggingface.co/SeaSky1027/AudioGAN)
- [🌐 Project Page](https://seasky1027.github.io/AudioGAN/)
This space uses only the CPU.
So, model inference may be slow.
Download the code from GitHub and the weights from huggingface.
Then, run inference on your own GPU for faster results.
"""
# Examples: Prompt and Seed only (no duration, steps, variant, etc.)
# Each row is [prompt, seed], matching the two Interface inputs.
examples_list = [
    ["Chopping meat on a wooden table.", 10],
    ["A bird is chirping in a quiet place.", 27],
    ["Melodic human whistling harmonizing with natural birdsong", 1027],
]
# Top-level Gradio app: (prompt, seed) -> (generated audio file, echoed prompt).
# NOTE(review): the Audio label emoji was mojibake ("π΅" = the UTF-8
# bytes of 🎵 decoded as ISO-8859-7); restored to 🎵.
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="AudioGAN: A Compact and Efficient Framework for Real-Time High-Fidelity Text-to-Audio Generation",
    description=description_text,
    flagging_mode="never",
    examples=examples_list,
    cache_examples=False,
)
if __name__ == "__main__":
    # Fetch weights, warm the model cache up front (first request would
    # otherwise pay the load cost), then serve with a bounded queue.
    ensure_models_downloaded()
    load_model_cache()
    gr_interface.queue(max_size=10).launch()