Files changed (1)
  1. app.py +21 -110
app.py CHANGED
@@ -1,4 +1,3 @@
-# Will be fixed soon, but meanwhile:
 import os
 if os.getenv('SPACES_ZERO_GPU') == "true":
     os.environ['SPACES_ZERO_GPU'] = "1"
@@ -6,27 +5,20 @@ if os.getenv('SPACES_ZERO_GPU') == "true":
 import gradio as gr
 import random
 import torch
-import os
+import numpy as np
 from torch import inference_mode
 from typing import Optional, List
-import numpy as np
 from models import load_model
 import utils
 import spaces
-import huggingface_hub
-from inversion_utils import inversion_forward_process, inversion_reverse_process
-

 LDM2 = "cvssp/audioldm2"
 MUSIC = "cvssp/audioldm2-music"
 LDM2_LARGE = "cvssp/audioldm2-large"
-STABLEAUD = "stabilityai/stable-audio-open-1.0"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ldm2 = load_model(model_id=LDM2, device=device)
 ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
 ldm2_music = load_model(model_id=MUSIC, device=device)
-ldm_stableaud = load_model(model_id=STABLEAUD, device=device, token=os.getenv('PRIV_TOKEN'))
-

 def randomize_seed_fn(seed, randomize_seed):
     if randomize_seed:
@@ -34,14 +26,10 @@ def randomize_seed_fn(seed, randomize_seed):
     torch.manual_seed(seed)
     return seed

-
 def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
-    # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
-
     with inference_mode():
         w0 = ldm_stable.vae_encode(x0)

-    # find Zs and wts - forward process
     _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1,
                                                        prompts=[prompt_src],
                                                        cfg_scales=[cfg_scale_src],
@@ -51,9 +39,7 @@ def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, durat
                                                        save_compute=save_compute)
     return zs, wts, extra_info

-
 def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute):
-    # reverse process (via Zs and wT)
     tstart = torch.tensor(tstart, dtype=torch.int)
     w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart,
                                       etas=1., prompts=[prompt_tar],
@@ -63,22 +49,17 @@ def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, d
                                       extra_info=extra_info,
                                       save_compute=save_compute)

-    # vae decode image
     with inference_mode():
         x0_dec = ldm_stable.vae_decode(w0)

-    if 'stable-audio' not in ldm_stable.model_id:
-        if x0_dec.dim() < 4:
-            x0_dec = x0_dec[None, :, :, :]
+    if x0_dec.dim() < 4:
+        x0_dec = x0_dec[None, :, :, :]

-        with torch.no_grad():
-            audio = ldm_stable.decode_to_mel(x0_dec)
-    else:
-        audio = x0_dec.squeeze(0).T
+    with torch.no_grad():
+        audio = ldm_stable.decode_to_mel(x0_dec)

     return (ldm_stable.get_sr(), audio.squeeze().cpu().numpy())

-
 def get_duration(input_audio,
                  model_id: str,
                  do_inversion: bool,
@@ -91,60 +72,37 @@
                  cfg_scale_tar: float = 12,
                  t_start: int = 45,
                  randomize_seed: bool = True,
-                 save_compute: bool = True,
-                 oauth_token: Optional[gr.OAuthToken] = None):
+                 save_compute: bool = True):
     if model_id == LDM2:
        factor = 1
     elif model_id == LDM2_LARGE:
        factor = 2.5
-    elif model_id == STABLEAUD:
-        factor = 3.2
     else:  # MUSIC
        factor = 1

     forwards = 0
     if do_inversion or randomize_seed:
-        forwards = steps if source_prompt == "" else steps * 2  # x2 when there is a prompt text
+        forwards = steps if source_prompt == "" else steps * 2
     forwards += int(t_start / 100 * steps) * 2

     duration = min(utils.get_duration(input_audio), utils.MAX_DURATION)
-    time_for_maxlength = factor * forwards * 0.15  # 0.25 is the time per forward pass
+    time_for_maxlength = factor * forwards * 0.15

-    if model_id != STABLEAUD:
-        time_for_maxlength = time_for_maxlength / utils.MAX_DURATION * duration
+    time_for_maxlength = time_for_maxlength / utils.MAX_DURATION * duration

     print('expected time:', time_for_maxlength)
     spare_time = 5
     return max(10, time_for_maxlength + spare_time)

-
-def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float,
-                        oauth_token: gr.OAuthToken | None):
+def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float):
     if input_audio is None:
         raise gr.Error('Input audio missing!')

     if tar_prompt == "":
         raise gr.Error("Please provide a target prompt to edit the audio.")

-    if src_prompt != "":
-        if model_id == STABLEAUD and cfg_scale_src != 1:
-            gr.Info("Consider using Source Guidance Scale=1 for Stable Audio Open 1.0.")
-        elif model_id != STABLEAUD and cfg_scale_src != 3:
-            gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.")
-
-    if model_id == STABLEAUD:
-        if oauth_token is None:
-            raise gr.Error("You must be logged in to use Stable Audio Open 1.0. Please log in and try again.")
-        try:
-            huggingface_hub.get_hf_file_metadata(huggingface_hub.hf_hub_url(STABLEAUD, 'transformer/config.json'),
-                                                 token=oauth_token.token)
-            print('Has Access')
-        # except huggingface_hub.utils._errors.GatedRepoError:
-        except huggingface_hub.errors.GatedRepoError:
-            raise gr.Error("You need to accept the license agreement to use Stable Audio Open 1.0. "
-                           "Visit the <a href='https://huggingface.co/stabilityai/stable-audio-open-1.0'>"
-                           "model page</a> to get access.")
-
+    if src_prompt != "" and cfg_scale_src != 3:
+        gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.")

 @spaces.GPU(duration=get_duration)
 def edit(input_audio,
@@ -159,32 +117,28 @@ def edit(input_audio,
          cfg_scale_tar: float = 12,
          t_start: int = 45,
          randomize_seed: bool = True,
-         save_compute: bool = True,
-         oauth_token: Optional[gr.OAuthToken] = None):
+         save_compute: bool = True):
     print(model_id)
     if model_id == LDM2:
         ldm_stable = ldm2
     elif model_id == LDM2_LARGE:
         ldm_stable = ldm2_large
-    elif model_id == STABLEAUD:
-        ldm_stable = ldm_stableaud
     else:  # MUSIC
         ldm_stable = ldm2_music

     ldm_stable.model.scheduler.set_timesteps(steps, device=device)

-    # If the inversion was done for a different model, we need to re-run the inversion
     if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
         do_inversion = True

     if input_audio is None:
         raise gr.Error('Input audio missing!')
     x0, _, duration = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device,
-                                       stft=('stable-audio' not in ldm_stable.model_id), model_sr=ldm_stable.get_sr())
+                                       stft=True, model_sr=ldm_stable.get_sr())
     if wts is None or zs is None:
         do_inversion = True

-    if do_inversion or randomize_seed:  # always re-run inversion
+    if do_inversion or randomize_seed:
         zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
                                                         num_diffusion_steps=steps,
                                                         cfg_scale_src=cfg_scale_src,
@@ -205,8 +159,6 @@ def edit(input_audio,
                                                         save_compute=save_compute)

     return output, wts.cpu(), zs.cpu(), [e.cpu() for e in extra_info if e is not None], saved_inv_model, do_inversion
-    # return output, wtszs_file, saved_inv_model, do_inversion
-

 def get_example():
     case = [
@@ -226,14 +178,6 @@
          '27s',
          'Examples/Beethoven_piano.mp3',
          ],
-        ['Examples/Beethoven.mp3',
-         '',
-         'Heavy Rock.',
-         40,
-         'stabilityai/stable-audio-open-1.0',
-         '27s',
-         'Examples/Beethoven_rock.mp3',
-         ],
         ['Examples/ModalJazz.mp3',
          'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
          'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
@@ -241,13 +185,6 @@
          'cvssp/audioldm2-music',
          '106s',
          'Examples/ModalJazz_banjo.mp3',],
-        ['Examples/Shadows.mp3',
-         '',
-         '8-bit arcade game soundtrack.',
-         40,
-         'stabilityai/stable-audio-open-1.0',
-         '34s',
-         'Examples/Shadows_arcade.mp3',],
         ['Examples/Cat.mp3',
          '',
          'A dog barking.',
@@ -258,14 +195,13 @@
     ]
     return case

-
 intro = """
 <h1 style="font-weight: 1000; text-align: center; margin: 0px;"> ZETA Editing 🎧 </h1>
 <h2 style="font-weight: 1000; text-align: center; margin: 0px;">
 Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️ </h2>
 <h3 style="margin-top: 0px; margin-bottom: 10px; text-align: center;">
-<a href="https://arxiv.org/abs/2402.10009">[Paper]</a>&nbsp;|&nbsp;
-<a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a>&nbsp;|&nbsp;
+<a href="https://arxiv.org/abs/2402.10009">[Paper]</a> |
+<a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a> |
 <a href="https://github.com/HilaManor/AudioEditingCode">[Code]</a>
 </h3>

@@ -275,22 +211,6 @@ For faster inference without waiting in queue, you may duplicate the space and u
 <img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" >
 </a>
 </p>
-<p style="margin: 0px;">
-<b>NEW - 15.10.24:</b> You can now edit using <b>Stable Audio Open 1.0</b>.
-You must be <b>logged in</b> after accepting the
-<b><a href="https://huggingface.co/stabilityai/stable-audio-open-1.0">license agreement</a></b> to use it.</br>
-</p>
-<ul style="padding-left:40px; line-height:normal;">
-<li style="margin: 0px;">Prompts behave differently - e.g.,
-try "8-bit arcade" directly instead of "a recording of...". Check out the new examples below!</li>
-<li style="margin: 0px;">Try to play around <code>T-start=40%</code>.</li>
-<li style="margin: 0px;">Under "More Options": Use <code>Source Guidance Scale=1</code>,
-and you can try fewer timesteps (even 20!).</li>
-<li style="margin: 0px;">Stable Audio Open is a general-audio model.
-For better music editing, duplicate the space and change to a
-<a href="https://huggingface.co/models?other=base_model:finetune:stabilityai/stable-audio-open-1.0">
-fine-tuned model for music</a>.</li>
-</ul>
 <p>
 <b>NEW - 15.10.24:</b> Parallel editing is enabled by default.
 To disable, uncheck <code>Efficient editing</code> under "More Options".
@@ -298,7 +218,6 @@ Saves a bit of time.
 </p>
 """

-
 help = """
 <div style="font-size:medium">
 <b>Instructions:</b><br>
@@ -319,21 +238,17 @@ to <code style="display:inline; background-color: lightgrey;">None</code>.
 </li>
 </ul>
 </div>
-
 """

 css = '.gradio-container {max-width: 1000px !important; padding-top: 1.5rem !important;}' \
       '.audio-upload .wrap {min-height: 0px;}'

-# with gr.Blocks(css='style.css') as demo:
 with gr.Blocks(css=css) as demo:
     def reset_do_inversion(do_inversion_user, do_inversion):
-        # do_inversion = gr.State(value=True)
         do_inversion = True
         do_inversion_user = True
         return do_inversion_user, do_inversion

-    # handle the case where the user clicked the button but the inversion was not done
     def clear_do_inversion_user(do_inversion_user):
         do_inversion_user = False
         return do_inversion_user
@@ -350,7 +265,7 @@ with gr.Blocks(css=css) as demo:
     zs = gr.State()
     extra_info = gr.State()
     saved_inv_model = gr.State()
-    do_inversion = gr.State(value=True)  # To save some runtime when editing the same thing over and over
+    do_inversion = gr.State(value=True)
     do_inversion_user = gr.State(value=False)

     with gr.Group():
@@ -371,15 +286,12 @@
         t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3,
                             info="Lower T-start -> closer to original audio. Higher T-start -> stronger edit.")
         model_id = gr.Dropdown(label="Model Version",
-                               choices=[LDM2,
-                                        LDM2_LARGE,
-                                        MUSIC,
-                                        STABLEAUD],
+                               choices=[LDM2, LDM2_LARGE, MUSIC],
                                info="Choose a checkpoint suitable for your audio and edit",
                                value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)
+
     with gr.Row():
         submit = gr.Button("Edit", variant="primary", scale=3)
-        gr.LoginButton(value="Login to HF (For Stable Audio)", scale=1)

     with gr.Accordion("More Options", open=False):
         with gr.Row():
@@ -435,7 +347,6 @@ with gr.Blocks(css=css) as demo:
         outputs=[do_inversion_user, do_inversion]
     )

-    # If sources changed we have to rerun inversion
     gr.on(
         triggers=[input_audio.change, src_prompt.change, model_id.change, cfg_scale_src.change,
                   steps.change, save_compute.change],
@@ -452,4 +363,4 @@
     )

     demo.queue()
-    demo.launch(state_session_capacity=15)
+    demo.launch(state_session_capacity=15)