AudioGAN / app.py
SeaSky1027's picture
Modify requirements.txt
1f2dc3f
"""
AudioGAN Hugging Face Space - Text-to-Audio Generation (CPU only)
GitHub: https://github.com/SeaSky1027/AudioGAN
Hugging Face Model: https://huggingface.co/SeaSky1027/AudioGAN
"""
import os
import sys
import warnings
import logging
from pathlib import Path
from datetime import datetime
warnings.filterwarnings("ignore")
import torch
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download
# Configure logging before anything else so every later step can report progress.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# The Space has no GPU: all tensors and modules are kept on the CPU.
device = "cpu"

# Directory that receives every generated .wav file.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Directory the Hugging Face model weights are downloaded to and loaded from.
CKPT_DIR = Path("./pretrained_models")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Loaded model dict; stays None until load_model_cache() fills it on first use.
MODEL_CACHE = None
def ensure_models_downloaded():
    """Download the AudioGAN weights from Hugging Face if any are missing.

    Checks ``CKPT_DIR`` for every checkpoint file the app loads from it. If at
    least one is absent, the ``SeaSky1027/AudioGAN`` model repository is
    mirrored into ``CKPT_DIR`` with ``snapshot_download``.
    """
    required_files = (
        "generator.pt",
        "hifigan_16k_64bins.json",
        "hifigan_16k_64bins.ckpt",
        # load_model_for_inference also passes CKPT_DIR together with these
        # two CLAP checkpoint names to CLAP's load_ckpt. Without them here, a
        # missing CLAP file never triggered a download.
        "music_speech_audioset_epoch_15_esc_89.98.pt",
        "630k-audioset-best.pt",
    )
    missing = [name for name in required_files if not (CKPT_DIR / name).exists()]
    if missing:
        log.info(
            "Downloading AudioGAN model from Hugging Face (missing: %s)...",
            ", ".join(missing),
        )
        snapshot_download(repo_id="SeaSky1027/AudioGAN", local_dir=str(CKPT_DIR))
    log.info("Checkpoint directory: %s", CKPT_DIR)
def load_model_for_inference(ckpt_path):
    """Build every model the demo uses from the weights stored in ``ckpt_path``.

    Mirrors the loading done by the project's ``infer.py``.

    Args:
        ckpt_path: Directory (str or Path) that holds the checkpoint files.

    Returns:
        dict with the keys "clap_encoder", "generator", "vocoder" and
        "clap_tester", each mapped to its loaded module in eval mode.
    """
    # Project-local modules, imported inside the function (presumably so this
    # app module can be imported before they are needed).
    import CLAP
    from generator import Generator
    from HiFiGAN.inference import get_vocoder

    ckpt_dir = Path(ckpt_path)
    models = {}

    # CLAP text encoder whose embeddings condition the generator.
    encoder = CLAP.CLAP_Module(amodel="HTSAT-base", tmodel="roberta").eval()
    models["clap_encoder"] = encoder
    encoder.load_ckpt(str(ckpt_dir), "music_speech_audioset_epoch_15_esc_89.98.pt")

    # GAN generator: turns noise + text embeddings into a mel spectrogram.
    gen = Generator()
    models["generator"] = gen
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load generator.pt from a trusted source.
    state = torch.load(ckpt_dir / "generator.pt", weights_only=False, map_location="cpu")
    gen.load_state_dict(state)

    # 16 kHz HiFi-GAN vocoder that converts mels to waveforms.
    models["vocoder"] = get_vocoder(sr=16000, ckpt_path=str(ckpt_dir)).eval()

    # Second CLAP model (HTSAT-tiny). The generation path in this file never
    # reads it; it is loaded to match infer.py.
    tester = CLAP.CLAP_Module(amodel="HTSAT-tiny", tmodel="roberta").eval()
    models["clap_tester"] = tester
    tester.load_ckpt(str(ckpt_dir), "630k-audioset-best.pt")

    # Make sure every module (the generator in particular) is in eval mode.
    for module in models.values():
        module.eval()
    return models
def get_text_embedding(text, model_dict):
    """Encode ``text`` with the CLAP text encoder held in ``model_dict``.

    Args:
        text: List of prompt strings passed straight to the encoder.
        model_dict: Model dict containing a "clap_encoder" entry.

    Returns:
        Tuple ``(sentence_embedding, word_embedding, sequence_lengths)``. Both
        embedding tensors are detached from any autograd graph.
    """
    encoder = model_dict["clap_encoder"]
    with torch.no_grad():
        sentence_emb, word_emb, seq_lengths = encoder.get_text_embedding(text)
        detached = (sentence_emb.detach(), word_emb.detach())
    return detached[0], detached[1], seq_lengths
def load_model_cache():
    """Return the process-wide model dict, loading it on first use.

    On the first call this makes sure the weights exist in CKPT_DIR
    (downloading them if needed), loads all models, moves each one to
    ``device``, and stores the result in MODEL_CACHE. Later calls return
    that same dict.
    """
    global MODEL_CACHE
    if MODEL_CACHE is not None:
        return MODEL_CACHE
    # The Gradio callback reaches this lazily. Make sure the weights are
    # present even when the app was not started through __main__.
    ensure_models_downloaded()
    log.info("Loading AudioGAN model (CPU)...")
    models = load_model_for_inference(CKPT_DIR)
    # Publish the cache only once every module is on the target device, so a
    # failure part-way through never leaves a half-initialized cache behind.
    MODEL_CACHE = {name: module.to(device) for name, module in models.items()}
    log.info("Model loaded.")
    return MODEL_CACHE
def _normalize_prompt(prompt):
    """Normalize a prompt as infer.py does: strip, drop one trailing '.', capitalize."""
    text = prompt.strip()
    if text.endswith("."):
        text = text[:-1]
    if text and text[0].islower():
        text = text[0].upper() + text[1:]
    return text


def _make_output_path(prompt, seed):
    """Return a fresh .wav path in OUTPUT_DIR named after the prompt, seed and time."""
    safe_prompt = "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
    # Fall back to a fixed stem when the prompt has no filename-safe characters.
    safe_prompt = safe_prompt.strip().replace(" ", "_")[:50] or "audio"
    # Microsecond resolution plus the seed stops rapid successive requests
    # with the same prompt from overwriting each other's files. A second-level
    # timestamp alone collided.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    return OUTPUT_DIR / f"{safe_prompt}_seed{seed}_{timestamp}.wav"


@torch.inference_mode()
def generate_audio_gradio(prompt, seed):
    """Generate a 16 kHz clip for ``prompt`` on the CPU (Gradio callback).

    Args:
        prompt: Text description of the sound. It must contain non-whitespace.
        seed: Seed for the generator's input noise. A fixed seed makes the
            noise, and therefore the output, reproducible.

    Returns:
        Tuple of (path to the saved .wav file, the prompt as entered).

    Raises:
        gr.Error: If the prompt is empty or blank, or the seed field is empty.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter the prompt.")
    # gr.Number can deliver None when the field is cleared; int(None) would crash.
    if seed is None:
        raise gr.Error("Enter the seed.")
    seed = int(seed)

    model_dict = load_model_cache()
    text = _normalize_prompt(prompt)

    # Reproducible noise from a dedicated RNG, so the global RNG is untouched.
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    sentence_embedding, word_embedding, sequence_lengths = get_text_embedding([text], model_dict)
    sentence_embedding = sentence_embedding.to(device)
    word_embedding = word_embedding.to(device)
    noise = torch.randn((1, 128), device=device, generator=generator)

    fake_mel = model_dict["generator"](noise, sentence_embedding, word_embedding, sequence_lengths)
    # squeeze() drops every singleton dim, including the batch of 1.
    # Presumably the vocoder wrapper expects an unbatched mel; TODO confirm.
    fake_sound = model_dict["vocoder"](fake_mel.squeeze())
    # torchaudio.save needs a 2-D (channels, frames) tensor.
    if fake_sound.dim() == 3:
        fake_sound = fake_sound.squeeze(1)
    elif fake_sound.dim() == 1:
        fake_sound = fake_sound.unsqueeze(0)
    fake_sound = fake_sound.detach().cpu()

    save_path = _make_output_path(prompt.strip(), seed)
    torchaudio.save(str(save_path), fake_sound, 16000)
    log.info(f"Saved: {save_path}")
    return str(save_path), prompt
# Interface inputs: a free-text prompt and an integer seed.
input_text = gr.Textbox(
    label="Prompt",
    placeholder="e.g. A bird is chirping in a quiet place.",
    lines=2,
)
# Integer seed (precision=0), limited to the unsigned 32-bit range.
seed = gr.Number(
    label="Seed",
    value=42,
    precision=0,
    step=1,
    minimum=0,
    maximum=0xFFFF_FFFF,
)
# Markdown shown under the interface title. The blank lines matter: without
# one after the bullet list, the CPU note is rendered as a continuation of the
# last bullet.
description_text = """
###
AudioGAN is a novel GAN-based model tailored for compact and efficient text-to-audio generation.

- [📖 Paper (arXiv)](https://arxiv.org/abs/2512.22166)
- [💻 GitHub](https://github.com/SeaSky1027/AudioGAN)
- [🤗 Hugging Face Model](https://huggingface.co/SeaSky1027/AudioGAN)
- [🌐 Project Page](https://seasky1027.github.io/AudioGAN/)

This Space runs on CPU only, so model inference may be slow.
For faster results, download the code from GitHub and the weights from Hugging Face,
then run inference on your own GPU.
"""
# Clickable demo examples. Each row is [prompt, seed], in the same order as
# the Interface inputs; no other generation options are exposed.
examples_list = [
    [example_prompt, example_seed]
    for example_prompt, example_seed in (
        ("Chopping meat on a wooden table.", 10),
        ("A bird is chirping in a quiet place.", 27),
        ("Melodic human whistling harmonizing with natural birdsong", 1027),
    )
]
# Single-page demo: (prompt, seed) -> (generated audio file, prompt echo).
# The audio label's emoji was mojibake ("🎡", the UTF-8 bytes of U+1F3A1
# decoded as cp125x). It is restored to the intended character.
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, seed],
    outputs=[
        # generate_audio_gradio returns a path, hence type="filepath".
        gr.Audio(label="🎡 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="AudioGAN: A Compact and Efficient Framework for Real-Time High-Fidelity Text-to-Audio Generation",
    description=description_text,
    flagging_mode="never",
    examples=examples_list,
    # Do not precompute example outputs at startup.
    cache_examples=False,
)
if __name__ == "__main__":
    # Fetch the weights and warm the model cache before serving. The first
    # request then skips the download and loading work.
    ensure_models_downloaded()
    load_model_cache()
    # Cap the request queue at 10 pending events, then start the server.
    app = gr_interface.queue(max_size=10)
    app.launch()