File size: 6,369 Bytes
a229ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f2dc3f
a229ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f2dc3f
 
a229ab6
 
 
 
 
 
1f2dc3f
 
 
 
a229ab6
 
 
 
1f2dc3f
 
 
a229ab6
 
 
 
 
 
 
 
 
1f2dc3f
a229ab6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
AudioGAN Hugging Face Space - Text-to-Audio Generation (CPU only)
GitHub: https://github.com/SeaSky1027/AudioGAN
Hugging Face Model: https://huggingface.co/SeaSky1027/AudioGAN
"""
import os
import sys
import warnings
import logging
from pathlib import Path
from datetime import datetime

warnings.filterwarnings("ignore")

import torch
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download

# CPU only for Space — free HF Spaces have no GPU, so inference runs on CPU.
device = "cpu"

# Output directory for generated audio clips served back through Gradio.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoint directory (Hugging Face model weights); populated on demand by
# ensure_models_downloaded().
CKPT_DIR = Path("./pretrained_models")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Global model cache — lazily filled by load_model_cache() so the heavy
# model load happens once per process.
MODEL_CACHE = None


def ensure_models_downloaded():
    """Fetch AudioGAN weights from the Hugging Face Hub if any are missing.

    Checks CKPT_DIR for the required checkpoint files and triggers a full
    repository snapshot download when at least one is absent.
    """
    required_files = (
        "generator.pt",
        "hifigan_16k_64bins.json",
        "hifigan_16k_64bins.ckpt",
    )
    missing = [name for name in required_files if not (CKPT_DIR / name).exists()]
    if missing:
        log.info("Downloading AudioGAN model from Hugging Face...")
        snapshot_download(repo_id="SeaSky1027/AudioGAN", local_dir=str(CKPT_DIR))
    log.info(f"Checkpoint directory: {CKPT_DIR}")


def load_model_for_inference(ckpt_path):
    """Load CLAP, Generator, HiFiGAN from checkpoint directory (same as infer.py).

    Args:
        ckpt_path: Directory containing all checkpoint files (str or Path).

    Returns:
        dict with keys "clap_encoder", "generator", "vocoder", "clap_tester",
        each a model set to eval() mode.
    """
    # Project-local modules; NOTE(review): presumably imported here (not at
    # module top) so importing this file doesn't require them — confirm.
    import CLAP
    from generator import Generator
    from HiFiGAN.inference import get_vocoder

    ckpt_path = Path(ckpt_path)
    model_dict = {}

    # Text encoder used to condition generation (HTSAT-base + RoBERTa).
    model_dict["clap_encoder"] = CLAP.CLAP_Module(amodel="HTSAT-base", tmodel="roberta").eval()
    model_dict["clap_encoder"].load_ckpt(str(ckpt_path), "music_speech_audioset_epoch_15_esc_89.98.pt")

    # GAN generator: maps noise + text embeddings to a mel spectrogram.
    model_dict["generator"] = Generator()
    ckpt_state = torch.load(ckpt_path / "generator.pt", weights_only=False, map_location="cpu")
    model_dict["generator"].load_state_dict(ckpt_state)

    # HiFiGAN vocoder: mel spectrogram -> 16 kHz waveform.
    model_dict["vocoder"] = get_vocoder(sr=16000, ckpt_path=str(ckpt_path)).eval()

    # Separate, smaller CLAP (HTSAT-tiny) — named "tester"; its role is not
    # visible in this file (unused by generate_audio_gradio).
    model_dict["clap_tester"] = CLAP.CLAP_Module(amodel="HTSAT-tiny", tmodel="roberta").eval()
    model_dict["clap_tester"].load_ckpt(str(ckpt_path), "630k-audioset-best.pt")

    # Redundant safety pass: force every model into eval mode.
    for name in model_dict:
        model_dict[name].eval()

    return model_dict


def get_text_embedding(text, model_dict):
    """Encode text with the CLAP text encoder, gradient-free.

    Args:
        text: Text input accepted by CLAP's ``get_text_embedding``.
        model_dict: Model dict containing a "clap_encoder" entry.

    Returns:
        (sentence_embedding, word_embedding, sequence_lengths) — the two
        embedding tensors detached from the autograd graph.
    """
    encoder = model_dict["clap_encoder"]
    with torch.no_grad():
        sent_emb, word_emb, seq_lens = encoder.get_text_embedding(text)
    return sent_emb.detach(), word_emb.detach(), seq_lens


def load_model_cache():
    """Return the process-wide model dict, loading it onto CPU on first use."""
    global MODEL_CACHE
    if MODEL_CACHE is None:
        log.info("Loading AudioGAN model (CPU)...")
        loaded = load_model_for_inference(CKPT_DIR)
        # Keep every model on CPU (device == "cpu" in this Space).
        MODEL_CACHE = {name: model.to(device) for name, model in loaded.items()}
        log.info("Model loaded.")
    return MODEL_CACHE


@torch.inference_mode()
def generate_audio_gradio(prompt, seed):
    """Generate a 16 kHz audio clip from a text prompt with a fixed seed (CPU).

    Args:
        prompt: Text description of the desired sound.
        seed: RNG seed for the latent noise; coerced to int.

    Returns:
        (path_to_saved_wav, prompt) — matching the two Gradio outputs.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter the prompt.")

    model_dict = load_model_cache()

    # Normalize text (same as infer.py): drop one trailing period and
    # capitalize the first letter.
    text = prompt.strip()
    if text.endswith("."):
        text = text[:-1]
    if text and text[0].islower():
        text = text[0].upper() + text[1:]

    text_list = [text]

    # Reproducible latent noise for a given seed.
    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))

    sentence_embedding, word_embedding, sequence_lengths = get_text_embedding(text_list, model_dict)
    sentence_embedding = sentence_embedding.to(device)
    word_embedding = word_embedding.to(device)

    noise = torch.randn((1, 128), device=device, generator=generator)
    fake_mel = model_dict["generator"](noise, sentence_embedding, word_embedding, sequence_lengths)
    fake_sound = model_dict["vocoder"](fake_mel.squeeze())

    # torchaudio.save requires a 2-D (channels, time) tensor; normalize both
    # possible vocoder output ranks (the original only handled the 3-D case,
    # so a 1-D waveform would fail to save).
    if fake_sound.dim() == 3:
        fake_sound = fake_sound.squeeze(1)
    elif fake_sound.dim() == 1:
        fake_sound = fake_sound.unsqueeze(0)
    fake_sound = fake_sound.detach().cpu()

    # Build a filesystem-safe filename stem from the prompt; fall back to a
    # fixed stem when the prompt has no safe characters at all (e.g. "???").
    safe_prompt = "".join(c for c in prompt if c.isalnum() or c in (" ", "_")).rstrip().replace(" ", "_")[:50]
    if not safe_prompt:
        safe_prompt = "audio"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_prompt}_{timestamp}.wav"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), fake_sound, 16000)
    # Lazy %-formatting keeps the f-string cost out of the logging call.
    log.info("Saved: %s", save_path)

    return str(save_path), prompt


# UI: Prompt, Audio Sample, Seed, Examples only
input_text = gr.Textbox(lines=2, label="Prompt", placeholder="e.g. A bird is chirping in a quiet place.")
# Seed is constrained to the uint32 range and rendered as an integer field.
seed = gr.Number(value=42, label="Seed", minimum=0, maximum=2**32 - 1, step=1, precision=0)

# Markdown shown under the title on the Space page.
description_text = """
###
AudioGAN is a novel GAN-based model tailored for compact and efficient text-to-audio generation.

- [πŸ“– Paper (arXiv)](https://arxiv.org/abs/2512.22166)  
- [πŸ’» GitHub](https://github.com/SeaSky1027/AudioGAN)  
- [πŸ€— Hugging Face Model](https://huggingface.co/SeaSky1027/AudioGAN)  
- [🌐 Project Page](https://seasky1027.github.io/AudioGAN/)

This space uses only the CPU.
So, model inference may be slow.
Download the code from GitHub and the weights from huggingface.
Then, run inference on your own GPU for faster results.
"""

# Examples: Prompt and Seed only (no duration, steps, variant, etc.)
examples_list = [
    ["Chopping meat on a wooden table.", 10],
    ["A bird is chirping in a quiet place.", 27],
    ["Melodic human whistling harmonizing with natural birdsong", 1027],
]

# Top-level Gradio app: two inputs (prompt, seed) -> (audio file, echoed prompt).
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, seed],
    outputs=[
        gr.Audio(label="🎡 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="AudioGAN: A Compact and Efficient Framework for Real-Time High-Fidelity Text-to-Audio Generation",
    description=description_text,
    flagging_mode="never",
    examples=examples_list,
    # Examples are generated on click rather than precomputed (CPU is slow).
    cache_examples=False,
)

if __name__ == "__main__":
    # Warm start: fetch weights and load the model before serving so the
    # first user request does not pay the download/load cost.
    ensure_models_downloaded()
    load_model_cache()
    # Bound the request queue; CPU inference is slow, so cap waiting jobs.
    gr_interface.queue(max_size=10).launch()