chenxie95 commited on
Commit
b419370
·
1 Parent(s): d1d9955

add app code

Browse files
Files changed (1) hide show
  1. app.py +244 -5
app.py CHANGED
@@ -1,9 +1,248 @@
1
  import gradio as gr
2
  import torch
3
- import audio_controlnet
 
 
 
 
 
 
 
 
4
 
5
- def greet(name):
6
- return f"Hello {name}!! Torch is {torch.__version__}. Cuda is available: {torch.cuda.is_available()}. audio_controlnet: {audio_controlnet}"
7
 
8
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
9
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
+ import librosa
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.cm as cm
7
+ import json5
8
+ import torchaudio
9
+ import tempfile
10
+ import os
11
+ from audio_controlnet.infer import AudioControlNet
12
 
13
MAX_DURATION = 10.0 # seconds — reference clips are truncated/zero-padded to this length everywhere below
 
14
 
15
+ # -----------------------------
16
+ # Feature extraction utilities
17
+ # -----------------------------
18
def process_audio_clip(audio, max_duration=None):
    """Truncate or zero-pad a Gradio audio tuple to a fixed clip length.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        Gradio "numpy" audio: ``(sample_rate, samples)``. Samples may be
        mono ``(n,)`` or multi-channel ``(n, channels)``.
    max_duration : float | None
        Target length in seconds; defaults to the module-level
        ``MAX_DURATION`` when None (backward-compatible addition).

    Returns
    -------
    tuple[int, np.ndarray] | None
        ``(sample_rate, float32 samples)`` of exactly the target length,
        or None when no audio was provided.
    """
    if audio is None:
        return None
    if max_duration is None:
        max_duration = MAX_DURATION
    sr, y = audio
    y = y.astype(np.float32)
    num_samples = int(max_duration * sr)
    if y.shape[0] > num_samples:
        y = y[:num_samples]
    elif y.shape[0] < num_samples:
        padding = num_samples - y.shape[0]
        if y.ndim == 1:
            y = np.pad(y, (0, padding))
        else:
            # Pad only the time axis. The original 1-D pad spec broadcast to
            # every axis, so stereo input also grew extra "channels".
            y = np.pad(y, ((0, padding), (0, 0)))
    return (sr, y)
30
+
31
def extract_loudness(audio):
    """Plot the RMS energy envelope of a reference clip.

    Returns a matplotlib Figure for the Gradio Plot preview, or None when
    no audio has been provided yet.
    """
    clip = process_audio_clip(audio)
    if clip is None:
        return None
    sr, samples = clip
    if samples.ndim == 2:
        samples = samples.mean(axis=1)  # downmix to mono for analysis

    envelope = librosa.feature.rms(y=samples)[0]
    frame_times = librosa.times_like(envelope, sr=sr)

    fig, axis = plt.subplots(figsize=(8, 3))
    axis.plot(frame_times, envelope)
    axis.set_title("Loudness (RMS)")
    axis.set_xlabel("Time (s)")
    axis.set_ylabel("Energy")
    fig.tight_layout()
    return fig
48
+
49
def extract_pitch(audio):
    """Plot the F0 contour (probabilistic YIN) of a reference clip.

    Returns a matplotlib Figure for the Gradio Plot preview, or None when
    no audio has been provided yet.
    """
    clip = process_audio_clip(audio)
    if clip is None:
        return None
    sr, samples = clip
    if samples.ndim == 2:
        samples = samples.mean(axis=1)  # downmix to mono for analysis

    f0_hz, _voiced_flag, _voiced_prob = librosa.pyin(
        samples,
        fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'),
    )
    frame_times = librosa.times_like(f0_hz, sr=sr)

    fig, axis = plt.subplots(figsize=(8, 3))
    axis.plot(frame_times, f0_hz)
    axis.set_title("Pitch (F0 contour)")
    axis.set_xlabel("Time (s)")
    axis.set_ylabel("Frequency (Hz)")
    fig.tight_layout()
    return fig
70
+
71
def visualize_events(json_str):
    """Render a ``{label: [[start, end], ...]}`` JSON5 string as a timeline.

    Used as a live preview on textbox ``.change``, so partially-typed or
    invalid input must quietly produce no plot rather than raise.

    Returns a matplotlib Figure, or None when the input does not parse
    into a mapping.
    """
    try:
        events = json5.loads(json_str)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; any parse failure just means "no preview yet".
        return None
    if not isinstance(events, dict):
        # json5 may legally parse to a list/number/string; only a mapping
        # of label -> intervals can be drawn (events.keys() below).
        return None

    fig, ax = plt.subplots(figsize=(8, 3))
    # cm.get_cmap() was removed in matplotlib 3.9; go through pyplot instead.
    cmap = plt.get_cmap("tab10")
    labels = list(events.keys())
    color_map = {label: cmap(i % 10) for i, label in enumerate(labels)}

    for i, (label, intervals) in enumerate(events.items()):
        color = color_map[label]
        for start, end in intervals:
            if start >= MAX_DURATION:
                continue  # interval lies entirely past the clip cap
            end = min(end, MAX_DURATION)
            ax.barh(i, end - start, left=start, height=0.5, color=color)

    ax.set_yticks(range(len(events)))
    ax.set_yticklabels(labels)
    ax.set_xlabel("Time (s)")
    ax.set_title("Sound Events Timeline")
    ax.set_xlim(0, MAX_DURATION)
    fig.tight_layout()
    return fig
97
+
98
# -----------------------------
# AudioControlNet Initialization
# -----------------------------
# Prefer GPU when available; the Space may run CPU-only.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Compose the three single-condition adapters (loudness / pitch / events)
# into one multi-control model, loaded once at import time so every request
# reuses the same weights.
model = AudioControlNet.from_multi_controlnets(
    [
        "juhayna/T2A-Adapter-loudness-v1.0",
        "juhayna/T2A-Adapter-pitch-v1.0",
        "juhayna/T2A-Adapter-events-v1.0",
    ],
    device=DEVICE,
)
110
+
111
+ # -----------------------------
112
+ # Temporary WAV utility
113
+ # -----------------------------
114
def save_temp_wav(audio):
    """Write a Gradio audio tuple to a temporary mono WAV file.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        Gradio "numpy" audio: ``(sample_rate, samples)``; samples may be
        integer PCM (Gradio's usual format for uploads/recordings —
        TODO confirm against the deployed Gradio version) or float.

    Returns
    -------
    str | None
        Path of the temp WAV (the caller is responsible for deleting it),
        or None when no audio was provided.
    """
    if audio is None:
        return None
    sr, y = audio
    if np.issubdtype(y.dtype, np.integer):
        # Rescale integer PCM to the [-1, 1] float range torchaudio.save
        # expects; writing raw int16-magnitude floats clips to garbage.
        y = y.astype(np.float32) / float(np.iinfo(y.dtype).max)
    else:
        y = y.astype(np.float32)
    if y.ndim == 2:
        y = y.mean(axis=1)  # downmix to mono
    waveform = torch.from_numpy(y).unsqueeze(0)  # (1, num_samples)
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    # Close our handle before torchaudio reopens the path — required on
    # Windows and avoids leaking the file descriptor.
    tmp.close()
    torchaudio.save(tmp.name, waveform, sr)
    return tmp.name
124
+
125
+ # -----------------------------
126
+ # Generate audio
127
+ # -----------------------------
128
def generate_audio(text, cond_loudness, cond_pitch, cond_events):
    """Run ControlNet-conditioned text-to-audio generation.

    Parameters
    ----------
    text : str
        The caption / text prompt.
    cond_loudness, cond_pitch : tuple[int, np.ndarray] | None
        Optional Gradio audio references controlling loudness / pitch.
    cond_events : str
        Optional JSON5 string of ``{label: [[start, end], ...]}`` events.

    Returns
    -------
    tuple[int, np.ndarray]
        ``(sample_rate, waveform)`` for the ``gr.Audio`` output.
    """
    control = {}
    temp_files = []

    try:
        # Independent `if`s (was an elif chain): the model is composed via
        # from_multi_controlnets, so multiple simultaneous conditions can be
        # forwarded. The UI's tab-select handlers clear the other inputs,
        # so in practice at most one branch fires — behavior is unchanged.
        if cond_loudness is not None:
            wav_path = save_temp_wav(cond_loudness)
            temp_files.append(wav_path)
            control["loudness"] = model.prepare_loudness(wav_path)

        if cond_pitch is not None:
            wav_path = save_temp_wav(cond_pitch)
            temp_files.append(wav_path)
            control["pitch"] = model.prepare_pitch(wav_path)

        if cond_events:
            control["events"] = json5.loads(cond_events)

        with torch.no_grad():  # inference only; no autograd state needed
            res = model.infer(
                caption=text,
                control=control if len(control) > 0 else None,
            )

        audio = res.audio.squeeze(0).cpu().numpy()
        sr = res.sample_rate
        return (sr, audio)

    finally:
        # Always delete the temporary WAVs, even when inference raises.
        for f in temp_files:
            if f and os.path.exists(f):
                os.remove(f)
161
+
162
+ # -----------------------------
163
+ # Gradio Interface
164
+ # -----------------------------
165
# Soft blue color theme applied to the whole interface.
blue_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="sky", neutral_hue="slate")

# Greyed-out example shown in the Sound Events textbox. JSON5 is used for
# parsing elsewhere, which is why the `// example` comment line is legal.
EVENTS_PLACEHOLDER = '''
// example
{
"Video game sound": [[0.0, 10.0]],
"Male speech, man speaking": [[0.015, 3.829], [4.293, 4.875], [5.089, 7.349], [8.071, 9.978]]
}
'''.strip()
174
+
175
with gr.Blocks(theme=blue_theme, title="Audio ControlNet – Text to Audio") as demo:
    # Page header.
    gr.Markdown("""
# 🎵 Audio ControlNet
## Text-to-Audio Generation with Conditions
Base T2A interface with conditional inputs for **Audio ControlNet**.
""")
    # CSS hook used by the three preview plots (elem_classes="plot-small")
    # to keep them compact.
    gr.HTML("""
<style>
.plot-small { height: 250px !important; }
</style>
""")

    with gr.Row():
        # Left column: text prompt plus one tab per control signal.
        with gr.Column(scale=2):
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="A calm ambient soundscape with soft pads and distant piano",
                lines=4,
            )

            with gr.Tabs() as tabs:
                with gr.Tab("Loudness") as tab_loudness:
                    with gr.Row():
                        with gr.Column(scale=1):
                            loudness_audio = gr.Audio(label="Loudness Reference Audio (up to 10 sec)", type="numpy")
                        with gr.Column(scale=1):
                            loudness_plot = gr.Plot(label="Loudness Curve (Reference Audio)", elem_classes="plot-small")

                with gr.Tab("Pitch") as tab_pitch:
                    with gr.Row():
                        with gr.Column(scale=1):
                            pitch_audio = gr.Audio(label="Pitch Reference Audio (up to 10 sec)", type="numpy")
                        with gr.Column(scale=1):
                            pitch_plot = gr.Plot(label="Pitch Curve (Reference Audio)", elem_classes="plot-small")

                with gr.Tab("Sound Events") as tab_events:
                    with gr.Row():
                        with gr.Column(scale=1):
                            sound_events = gr.Textbox(label="Sound Events (JSON)", placeholder=EVENTS_PLACEHOLDER, lines=8)
                        with gr.Column(scale=1):
                            events_plot = gr.Plot(label="Sound Events Roll", elem_classes="plot-small")

            generate_btn = gr.Button("Generate Audio", variant="primary")

        # Right column: generated result.
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="numpy")

    # Live previews: re-plot whenever a reference input changes.
    loudness_audio.change(fn=extract_loudness, inputs=loudness_audio, outputs=loudness_plot)
    pitch_audio.change(fn=extract_pitch, inputs=pitch_audio, outputs=pitch_plot)
    sound_events.change(fn=visualize_events, inputs=sound_events, outputs=events_plot)

    generate_btn.click(
        fn=generate_audio,
        inputs=[text_prompt, loudness_audio, pitch_audio, sound_events],
        outputs=audio_output
    )

    # Selecting a tab clears the other two tabs' inputs, so at most one
    # control condition reaches generate_audio at a time.
    tab_loudness.select(lambda: (None, None), [], [pitch_audio, sound_events])
    tab_pitch.select(lambda: (None, None), [], [loudness_audio, sound_events])
    tab_events.select(lambda: (None, None), [], [loudness_audio, pitch_audio])

    # Footer: short description of the three control inputs.
    gr.Markdown("""
---
**Control Inputs**
- **Loudness**: reference audio controlling energy / dynamics
- **Pitch**: reference audio controlling pitch contour
- **Sound Events**: symbolic event-level constraints in JSON format
""")
243
+
244
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the standard HF Spaces setup.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )