File size: 12,274 Bytes
52cb40a
2a74df9
 
 
 
 
52cb40a
b419370
 
 
 
 
 
 
 
2a74df9
b419370
52cb40a
77b9459
 
 
b419370
52cb40a
2a74df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b419370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a74df9
b419370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21d591d
 
 
2a74df9
 
b419370
 
 
2a74df9
 
 
b419370
 
 
 
 
 
2a74df9
 
b419370
 
 
2a74df9
b419370
 
 
 
 
 
 
2a74df9
b419370
21d591d
b419370
 
 
b50b1c3
 
 
21d591d
2a74df9
b50b1c3
 
 
b419370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a74df9
21d591d
 
 
2a74df9
 
 
 
 
 
b419370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6634a6f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import gradio as gr
try:
    import spaces
    require_gpu = spaces.GPU
except:
    require_gpu = lambda f: f
import torch
import numpy as np
import librosa
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import json5
import torchaudio
import tempfile
import os
import random
from audio_controlnet.infer import AudioControlNet

import logging
logging.getLogger("gradio").setLevel(logging.WARNING)

# Length of the working window: reference audio is trimmed/zero-padded to this
# duration and the sound-events timeline is clipped to it.
MAX_DURATION = 10.0  # seconds

# -----------------------------
# Random Examples Data
# -----------------------------
# Built-in demo inputs. Each entry has:
#   "caption": the text prompt shown in the prompt box
#   "events":  mapping of sound-event label -> list of [start, end] intervals
#              (plotted on a seconds axis; all values here lie within 0-10)
RANDOM_EXAMPLES = [
  {
    "caption": "People speak and clap, a child speaks and a camera clicks.",
    "events": {
      "Female speech, woman speaking": [[0.0, 3.969], [7.913, 8.157], [8.189, 9.654]],
      "Child speech, kid speaking": [[9.724, 10.0]]
    }
  },
  {
    "caption": "Background noise, tapping, and cat sounds are interspersed with purring.",
    "events": {
      "Cat": [[0.978, 2.291], [9.032, 10.0]]
    }
  },
  {
    "caption": "Water flows and dishes clatter with child speech and laughter.",
    "events": {
      "Child speech, kid speaking": [[0.0, 1.503], [1.732, 2.12], [2.942, 3.541], [7.803, 8.493]],
      "Dishes, pots, and pans": [[1.983, 2.156], [3.175, 3.298], [4.774, 5.076], [5.711, 5.834], [6.076, 6.24], [6.423, 7.012]],
      "Male speech, man speaking": [[8.547, 9.557]],
      "Water tap, faucet": [[0.0, 10.0]]
    }
  },
  {
    "caption": "Speech babble and clattering dishes and silverware can be heard, along with a child's voice.",
    "events": {
      "Dishes, pots, and pans": [[0.85, 0.969], [1.386, 1.504], [7.717, 7.874]],
      "Male speech, man speaking": [[0.748, 1.173]],
      "Cutlery, silverware": [[4.693, 4.843], [5.299, 5.52]],
      "Female speech, woman speaking": [[1.63, 3.409]],
      "Child speech, kid speaking": [[8.756, 9.354]]
    }
  },
  {
    "caption": "A man is speaking, with background sounds of wind and a river, and another man sighing and speaking.", 
    "events": {"Male speech, man speaking": [[0.0, 7.851], [8.903, 9.129], [9.328, 9.98]], "Conversation": [[0.0, 9.98]], "Wind": [[0.0, 9.98]], "Stream, river": [[0.0, 9.98]], "Sigh": [[8.157, 8.707]]}
  },
  {
    "caption": "Wind noise and cowbell are heard twice.", 
    "events": {"Wind noise (microphone)": [[0.0, 1.15], [2.378, 2.961]], "Cowbell": [[0.0, 10.0]]}
  },
  {
    "caption": "There are mechanisms, bird calls, clicking, and male speech.", 
    "events": {"Mechanisms": [[0.0, 10.0]], "Bird vocalization, bird call, bird song": [[1.122, 1.423]], "Clicking": [[1.139, 1.238], [4.737, 4.858]], "Male speech, man speaking": [[1.95, 2.875], [5.182, 5.795], [6.113, 6.807], [7.386, 8.138], [8.236, 8.803], [9.427, 10.0]]}
  },
  {
    "caption": "Propeller noise and a sound effect.", 
    "events": {"Propeller, airscrew": [[1.779, 10.0]], "Sound effect": [[1.811, 2.868]]}
  },
  {
    "caption": "Women converse and laugh in a noisy crowd.", 
    "events": {"Female speech, woman speaking": [[0.0, 1.669], [2.097, 2.976], [4.66, 8.98]], "Conversation": [[0.0, 9.379]], "Background noise": [[0.0, 9.379]], "Generic impact sounds": [[0.096, 0.318], [3.707, 3.944], [6.107, 6.314], [7.584, 7.695], [8.256, 8.367]], "Laughter": [[1.573, 2.947], [4.461, 6.174], [9.002, 9.364]], "Crowd": [[1.573, 2.954], [4.512, 6.129], [9.002, 9.379]], "Tick": [[1.691, 1.795], [4.276, 4.372]], "Sound effect": [[3.212, 4.416]]}
  }
]
def build_events_json_text(events):
    """Render an events mapping as JSON text with one label per line.

    Args:
        events: Mapping of event label to a list of ``[start, end]`` intervals.

    Returns:
        A JSON object as a string. Every entry is indented by four spaces and
        written as ``"label": [[start, end], ...]``; an empty mapping yields
        an opening brace, a newline and a closing brace.
    """
    # Function-scope import so this block does not depend on the file header.
    import json

    # json.dumps escapes quotes/backslashes inside labels, which plain f-string
    # interpolation would emit verbatim and thereby break the JSON.
    entries = [
        f'    {json.dumps(str(label), ensure_ascii=False)}: {json.dumps(intervals)}'
        for label, intervals in events.items()
    ]
    if not entries:
        return '{\n}'
    # Joining guarantees there is no trailing comma. The previous
    # "append ',\n' then strip(',')" approach never stripped anything because
    # the text ended with a newline, not a comma.
    return '{\n' + ',\n'.join(entries) + '\n}'

def generate_random_example():
    """Pick one built-in example at random.

    Returns:
        A ``(caption, events_json_text)`` tuple, where the second item is the
        example's events mapping rendered by ``build_events_json_text``.
    """
    picked = random.choice(RANDOM_EXAMPLES)
    events_text = build_events_json_text(picked["events"])
    caption = picked["caption"]
    return caption, events_text

# -----------------------------
# Feature extraction utilities
# -----------------------------
def process_audio_clip(audio, max_duration=None):
    """Trim or zero-pad an audio clip to a fixed duration.

    Args:
        audio: ``(sample_rate, samples)`` tuple or ``None``. ``samples`` is a
            NumPy array whose first axis is time; any further axes (e.g.
            channels) are left untouched.
        max_duration: Target length in seconds. Defaults to the module-level
            ``MAX_DURATION`` when omitted.

    Returns:
        ``None`` if ``audio`` is ``None``; otherwise ``(sample_rate, samples)``
        with ``samples`` cast to float32 and exactly
        ``int(max_duration * sample_rate)`` entries along the first axis.
    """
    if audio is None:
        return None
    if max_duration is None:
        max_duration = MAX_DURATION
    sr, y = audio
    y = y.astype(np.float32)
    num_samples = int(max_duration * sr)
    if y.shape[0] > num_samples:
        y = y[:num_samples]
    elif y.shape[0] < num_samples:
        padding = num_samples - y.shape[0]
        # A bare (0, padding) is broadcast by np.pad to *every* axis, which
        # would also append zero "channels" to stereo input. Pad time only.
        pad_width = [(0, padding)] + [(0, 0)] * (y.ndim - 1)
        y = np.pad(y, pad_width)
    return (sr, y)

def extract_loudness(audio):
    """Plot the RMS loudness curve of a reference clip.

    Args:
        audio: ``(sample_rate, samples)`` tuple or ``None``.

    Returns:
        A Matplotlib ``Figure`` with RMS energy over time, or ``None`` when no
        audio is given.
    """
    # Function-scope import so this block does not depend on the file header.
    from matplotlib.figure import Figure

    audio = process_audio_clip(audio)
    if audio is None:
        return None
    sr, y = audio
    if y.ndim == 2:
        # Collapse (samples, channels) to mono before analysis.
        y = y.mean(axis=1)
    rms = librosa.feature.rms(y=y)[0]
    times = librosa.times_like(rms, sr=sr)

    # Build the figure without pyplot: plt.subplots() registers every figure
    # in pyplot's global state, and since nothing here ever closes them they
    # would pile up in a long-running server.
    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    ax.plot(times, rms)
    ax.set_title("Loudness (RMS)")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Energy")
    fig.tight_layout()
    return fig

def extract_pitch(audio):
    """Plot the fundamental-frequency (F0) contour of a reference clip.

    Args:
        audio: ``(sample_rate, samples)`` tuple or ``None``.

    Returns:
        A Matplotlib ``Figure`` with the pYIN F0 estimate over time, or
        ``None`` when no audio is given.
    """
    # Function-scope import so this block does not depend on the file header.
    from matplotlib.figure import Figure

    audio = process_audio_clip(audio)
    if audio is None:
        return None
    sr, y = audio
    if y.ndim == 2:
        # Collapse (samples, channels) to mono before analysis.
        y = y.mean(axis=1)
    # pyin must be told the real sample rate: without it librosa assumes its
    # default of 22050 Hz, so the reported frequencies are scaled wrongly for
    # any clip recorded at another rate (while the time axis below uses sr).
    f0, _, _ = librosa.pyin(
        y,
        fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'),
        sr=sr,
    )
    times = librosa.times_like(f0, sr=sr)

    # Build the figure without pyplot so it is not retained in pyplot's
    # global figure registry (nothing here ever closes figures).
    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    ax.plot(times, f0)
    ax.set_title("Pitch (F0 contour)")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    fig.tight_layout()
    return fig

def visualize_events(json_str):
    """Draw a timeline of sound events described by JSON text.

    Args:
        json_str: JSON/JSON5 text for a mapping of event label to a list of
            ``[start, end]`` intervals.

    Returns:
        A Matplotlib ``Figure`` with one horizontal lane per label, or
        ``None`` when the text cannot be parsed into a mapping. Intervals that
        are not a pair of numbers are skipped; intervals starting at or after
        ``MAX_DURATION`` are dropped and the rest are clipped to it.
    """
    # Function-scope import so this block does not depend on the file header.
    from matplotlib.figure import Figure

    try:
        events = json5.loads(json_str)
    except Exception:
        # The text changes on every keystroke, so unparseable input is normal;
        # show no plot instead of raising. (Exception, not a bare except, so
        # KeyboardInterrupt/SystemExit still propagate.)
        return None
    if not isinstance(events, dict):
        # Valid JSON that is not an object (a list, number, null, ...).
        return None

    # Build the figure without pyplot so it is not retained in pyplot's
    # global figure registry (nothing here ever closes figures).
    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap remains available.
    cmap = plt.get_cmap("tab10")
    labels = list(events)

    for row, (label, intervals) in enumerate(events.items()):
        color = cmap(row % 10)
        if not isinstance(intervals, (list, tuple)):
            continue
        for interval in intervals:
            try:
                start, end = interval
                start, end = float(start), float(end)
            except (TypeError, ValueError):
                # Half-typed or malformed interval: leave it out of the plot.
                continue
            if start >= MAX_DURATION:
                continue
            end = min(end, MAX_DURATION)
            ax.barh(row, end - start, left=start, height=0.5, color=color)

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels)
    ax.set_xlabel("Time (s)")
    ax.set_title("Sound Events Timeline")
    ax.set_xlim(0, MAX_DURATION)
    fig.tight_layout()
    return fig

# -----------------------------
# AudioControlNet Initialization
# -----------------------------
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# One adapter checkpoint per control type offered in the UI
# (loudness, pitch, sound events), loaded once at import time.
model = AudioControlNet.from_multi_controlnets(
    [
        "juhayna/T2A-Adapter-loudness-v1.0",
        "juhayna/T2A-Adapter-pitch-v1.0",
        "juhayna/T2A-Adapter-events-v1.0",
    ],
    device=DEVICE,
)

# -----------------------------
# Temporary WAV utility
# -----------------------------
def save_temp_wav(audio):
    """Write a clip to a temporary mono WAV file and return its path.

    Args:
        audio: ``(sample_rate, samples)`` tuple or ``None``. 2-D ``samples``
            are averaged over axis 1 to obtain a single channel.

    Returns:
        Path of the written ``.wav`` file, or ``None`` when ``audio`` is
        ``None``. The caller is responsible for deleting the file.

    Raises:
        Whatever ``torchaudio.save`` raises; the temporary file is removed
        before the exception propagates.
    """
    if audio is None:
        return None
    sr, y = audio
    if y.ndim == 2:
        y = y.mean(axis=1)
    # NOTE(review): samples are cast with .float() but not rescaled. If the
    # caller supplies integer PCM (e.g. int16), the file keeps that integer
    # value range -- confirm this is what model.prepare_* expects.
    y = torch.from_numpy(y).float().unsqueeze(0)

    # Only a unique path is needed. Close our own handle immediately so the
    # descriptor is not leaked and the writer can reopen the file by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        path = tmp.name
    try:
        torchaudio.save(path, y, sr)
    except Exception:
        # The caller never learns the path on failure, so clean up here.
        if os.path.exists(path):
            os.remove(path)
        raise
    return path

# -----------------------------
# Generate audio
# -----------------------------
@require_gpu
def generate_audio(text, cond_loudness, cond_pitch, cond_events):
    """Run text-to-audio inference with at most one control condition.

    The first populated input wins, in the order loudness, pitch, events.

    Args:
        text: Caption passed to the model.
        cond_loudness: ``(sample_rate, samples)`` loudness reference or ``None``.
        cond_pitch: ``(sample_rate, samples)`` pitch reference or ``None``.
        cond_events: JSON/JSON5 text describing sound events, or empty/``None``.

    Returns:
        ``(sample_rate, waveform)`` with the waveform as a NumPy array.
    """
    conditions = {}
    scratch_paths = []

    def _stage(reference):
        # Write the reference clip to disk and remember it for cleanup.
        path = save_temp_wav(reference)
        scratch_paths.append(path)
        return path

    try:
        if cond_loudness is not None:
            conditions["loudness"] = model.prepare_loudness(_stage(cond_loudness))
        elif cond_pitch is not None:
            conditions["pitch"] = model.prepare_pitch(_stage(cond_pitch))
        elif cond_events:
            conditions["events"] = json5.loads(cond_events)

        with torch.no_grad():
            # An empty dict is passed on as None (no control).
            result = model.infer(caption=text, control=conditions or None)

        waveform = result.audio.squeeze(0).cpu().numpy()
        return (result.sample_rate, waveform)
    finally:
        # Temporary WAVs are removed whether or not inference succeeded.
        for path in scratch_paths:
            if path and os.path.exists(path):
                os.remove(path)

# -----------------------------
# Gradio Interface
# -----------------------------
# Soft theme with a blue / sky / slate palette; passed to gr.Blocks below.
blue_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="sky", neutral_hue="slate")

# Generate initial random example for page load.
# This runs once at import time, so the same example pre-fills the prompt and
# events boxes until the process is restarted.
initial_caption, initial_events = generate_random_example()

# Placeholder text for the prompt box.
CAPTION_PLACEHOLDER = 'Water flows and dishes clatter with child speech and laughter.'

# Placeholder text for the events box. The leading "// example" line is a
# JSON5-style comment (the events text is parsed with json5.loads).
EVENTS_PLACEHOLDER = '''
// example
{
    "Child speech, kid speaking": [[0.0, 1.503], [1.732, 2.12], [2.942, 3.541], [7.803, 8.493]],
    "Dishes, pots, and pans": [[1.983, 2.156], [3.175, 3.298], [4.774, 5.076], [5.711, 5.834], [6.076, 6.24], [6.423, 7.012]],
    "Water tap, faucet": [[0.0, 10.0]]
}
'''.strip()

# Page layout: prompt + one tab per control type on the left (scale 2),
# generated audio on the right (scale 1), followed by the event wiring.
with gr.Blocks(theme=blue_theme, title="Audio ControlNet – Text to Audio") as demo:
    gr.Markdown("""
        # 🎵 Audio ControlNet
        ## Fine-Grained Text-to-Audio Generation with Conditions
        T2A GUI interface with conditional inputs for **Audio ControlNet**.
    """)
    # Inline CSS backing the "plot-small" class applied to the gr.Plot components.
    gr.HTML("""
    <style>
    .plot-small { height: 280px !important; }
    </style>
    """)

    with gr.Row():
        with gr.Column(scale=2):
            # Pre-filled with the caption of the example drawn at import time.
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder=CAPTION_PLACEHOLDER,
                lines=4,
                value=initial_caption,
            )

            # Each tab holds one control input (left) and its preview plot (right).
            with gr.Tabs() as tabs:
                with gr.Tab("Sound Events") as tab_events:
                    with gr.Row():
                        with gr.Column(scale=1):
                            sound_events = gr.Textbox(label="Sound Events (JSON)", placeholder=EVENTS_PLACEHOLDER, lines=8, value=initial_events)
                            random_example_btn = gr.Button("🎲 Random Example", variant="primary", size="sm")
                        with gr.Column(scale=1):
                            events_plot = gr.Plot(label="Sound Events Roll", elem_classes="plot-small")
                            
                with gr.Tab("Loudness") as tab_loudness:
                    with gr.Row():
                        with gr.Column(scale=1):
                            loudness_audio = gr.Audio(label="Loudness Reference Audio (up to 10 sec)", type="numpy")
                        with gr.Column(scale=1):
                            loudness_plot = gr.Plot(label="Loudness Curve (Reference Audio)", elem_classes="plot-small")

                with gr.Tab("Pitch") as tab_pitch:
                    with gr.Row():
                        with gr.Column(scale=1):
                            pitch_audio = gr.Audio(label="Pitch Reference Audio (up to 10 sec)", type="numpy")
                        with gr.Column(scale=1):
                            pitch_plot = gr.Plot(label="Pitch Curve (Reference Audio)", elem_classes="plot-small")

            generate_btn = gr.Button("Generate Audio", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="numpy")

    # Refresh each preview plot whenever its control input changes.
    loudness_audio.change(fn=extract_loudness, inputs=loudness_audio, outputs=loudness_plot)
    pitch_audio.change(fn=extract_pitch, inputs=pitch_audio, outputs=pitch_plot)
    sound_events.change(fn=visualize_events, inputs=sound_events, outputs=events_plot)
    
    # Initialize events plot with the initial random example
    demo.load(fn=lambda: visualize_events(initial_events), inputs=[], outputs=events_plot)
    
    # Random example button event
    # (fills both the prompt and the events box with a freshly drawn example)
    random_example_btn.click(
        fn=generate_random_example,
        inputs=[],
        outputs=[text_prompt, sound_events]
    )

    # All three control inputs are passed; generate_audio decides which one to use.
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_prompt, loudness_audio, pitch_audio, sound_events],
        outputs=audio_output
    )

    # Selecting a tab resets the other two control inputs to None, so only the
    # control belonging to the visible tab can be populated at a time.
    tab_loudness.select(lambda: (None, None), [], [pitch_audio, sound_events])
    tab_pitch.select(lambda: (None, None), [], [loudness_audio, sound_events])
    tab_events.select(lambda: (None, None), [], [loudness_audio, pitch_audio])

    gr.Markdown("""
        ---
        **Control Inputs**
        - **Loudness**: reference audio controlling energy / dynamics
        - **Pitch**: reference audio controlling pitch contour
        - **Sound Events**: symbolic event-level constraints in JSON format
    """)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    # server_name="0.0.0.0" listens on all network interfaces (needed when the
    # app runs inside a container); quiet=True presumably suppresses Gradio's
    # startup output -- see the launch() documentation.
    demo.launch(server_name="0.0.0.0", quiet=True)