ljcortesr committed on
Commit
262705a
·
1 Parent(s): 958e4f6

Model loading once

Browse files
Files changed (1) hide show
  1. app.py +31 -29
app.py CHANGED
@@ -3,42 +3,44 @@ from audiocraft.models import AudioGen
3
  from audiocraft.data.audio import audio_write
4
  import os
5
  import gradio as gr
 
6
 
7
  model = AudioGen.get_pretrained('facebook/audiogen-medium')
8
  model.set_generation_params(duration=5) # generate 5 seconds.
9
 
10
- def generate_audio(descriptions):
11
- if not os.path.exists('audio_files'):
12
- os.makedirs('audio_files')
13
-
14
- wav = model.generate([descriptions]) # generates 3 samples.
15
- results = []
16
-
17
- for idx, one_wav in enumerate(wav):
18
- filename = f'{descriptions}.wav'
19
- file_path = os.path.join('audio_files', filename)
20
- audio_write(file_path, one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True, add_suffix=False)
21
- print(f"Generated audio for '{descriptions}'")
22
- results.append(file_path)
23
-
24
- return results[0]
 
 
 
 
 
 
25
 
26
  def ui_full():
27
- with gr. Blocks() as interface:
28
- gr.Markdown(
29
- """
30
- # AudioGen Demo
31
-
32
- """
33
- )
34
- with gr.Row():
35
- descriptions = gr.Textbox(lines=2, label="Enter descriptions of the audio to generate")
36
  with gr.Row():
37
  generate_button = gr.Button("Generate Audio")
38
- with gr.Row():
39
  output = gr.Audio(label="Generated Audio")
40
 
41
- generate_button.click(fn=generate_audio, inputs=descriptions, outputs=[output])
42
- interface.queue().launch()
43
-
44
- ui_full()
 
3
  from audiocraft.data.audio import audio_write
4
  import os
5
  import gradio as gr
6
+ import spaces
7
 
8
  model = AudioGen.get_pretrained('facebook/audiogen-medium')
9
  model.set_generation_params(duration=5) # generate 5 seconds.
10
 
11
+ OUTPUT_DIR = "audio_files"
12
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
13
+
14
+ @spaces.GPU
15
+ def generate_audio(descriptions: str):
16
+ safe_name = "_".join(descriptions.split())
17
+ output_path = os.path.join(OUTPUT_DIR, safe_name)
18
+
19
+ wav = model.generate([descriptions])
20
+ audio_write(
21
+ output_path,
22
+ wav[0].cpu(),
23
+ model.sample_rate,
24
+ strategy="loudness",
25
+ loudness_compressor=True,
26
+ add_suffix=False,
27
+ )
28
+
29
+ final_path = f"{output_path}.wav"
30
+ print(f"Generated audio for '{descriptions}' -> {final_path}")
31
+ return final_path
32
 
33
  def ui_full():
34
+ with gr.Blocks() as interface:
35
+ gr.Markdown("# AudioGen Demo")
36
+ with gr.Row():
37
+ descriptions = gr.Textbox(lines=2, label="Enter a description of the audio")
 
 
 
 
 
38
  with gr.Row():
39
  generate_button = gr.Button("Generate Audio")
40
+ with gr.Row():
41
  output = gr.Audio(label="Generated Audio")
42
 
43
+ generate_button.click(fn=generate_audio, inputs=descriptions, outputs=output)
44
+ return interface
45
+
46
+ demo = ui_full()