thecollabagepatch committed on
Commit
ff1abf5
·
1 Parent(s): ed6f2d5

ok herewego

Browse files
Files changed (1) hide show
  1. app.py +112 -105
app.py CHANGED
@@ -18,9 +18,9 @@ def preprocess_audio(waveform):
18
  waveform_np = waveform.cpu().squeeze().numpy()
19
  return torch.from_numpy(waveform_np).unsqueeze(0).to(device)
20
 
21
- # Fix: Add dummy parameter to avoid schema generation bug
22
  @spaces.GPU
23
- def generate_drum_sample(dummy_trigger="generate"):
24
  model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
25
  model.set_generation_params(duration=10)
26
  wav = model.generate_unconditional(1).squeeze(0)
@@ -32,140 +32,146 @@ def generate_drum_sample(dummy_trigger="generate"):
32
 
33
  return filename_with_extension
34
 
35
- @spaces.GPU
36
- def continue_drum_sample(existing_audio_path):
37
- if existing_audio_path is None:
38
- return None
 
 
 
 
 
 
39
 
40
- existing_audio, sr = torchaudio.load(existing_audio_path)
41
- existing_audio = existing_audio.to(device)
42
 
43
- prompt_duration = 2
44
- output_duration = 10
45
 
46
- num_samples = int(prompt_duration * sr)
47
- if existing_audio.shape[1] < num_samples:
48
- raise ValueError("The existing audio is too short for the specified prompt duration.")
49
 
50
- start_sample = existing_audio.shape[1] - num_samples
51
- prompt_waveform = existing_audio[..., start_sample:]
52
 
53
- model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
54
- model.set_generation_params(duration=output_duration)
55
 
56
- output = model.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
57
- output = output.to(device)
58
 
59
- if output.dim() == 3:
60
- output = output.squeeze(0)
61
 
62
- if output.dim() == 1:
63
- output = output.unsqueeze(0)
64
 
65
- combined_audio = torch.cat((existing_audio, output), dim=1)
66
- combined_audio = combined_audio.cpu()
67
 
68
- combined_file_path = f'./continued_jungle_{random.randint(1000, 9999)}.wav'
69
- torchaudio.save(combined_file_path, combined_audio, sr)
70
 
71
- return combined_file_path
72
 
73
- @spaces.GPU
74
- def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
75
- if wav_filename is None:
76
- return None
77
 
78
- song, sr = torchaudio.load(wav_filename)
79
- song = song.to(device)
80
-
81
- model_name = musicgen_model.split(" ")[0]
82
- model_continue = MusicGen.get_pretrained(model_name)
83
-
84
- model_continue.set_generation_params(
85
- use_sampling=True,
86
- top_k=250,
87
- top_p=0.0,
88
- temperature=1.0,
89
- duration=output_duration,
90
- cfg_coef=3
91
- )
92
-
93
- prompt_waveform = song[..., :int(prompt_duration * sr)]
94
- prompt_waveform = preprocess_audio(prompt_waveform)
95
 
96
- output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
97
- output = output.cpu()
98
 
99
- if len(output.size()) > 2:
100
- output = output.squeeze()
101
 
102
- filename_without_extension = f'continued_music'
103
- filename_with_extension = f'{filename_without_extension}.wav'
104
- audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
105
 
106
- return filename_with_extension
107
 
108
- @spaces.GPU
109
- def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
110
- if input_audio_path is None:
111
- return None
112
 
113
- song, sr = torchaudio.load(input_audio_path)
114
- song = song.to(device)
115
 
116
- model_continue = MusicGen.get_pretrained(musicgen_model.split(" ")[0])
117
- model_continue.set_generation_params(
118
- use_sampling=True,
119
- top_k=250,
120
- top_p=0.0,
121
- temperature=1.0,
122
- duration=output_duration,
123
- cfg_coef=3
124
- )
125
 
126
- original_audio = AudioSegment.from_mp3(input_audio_path)
127
- current_audio = original_audio
128
 
129
- file_paths_for_cleanup = []
130
 
131
- for i in range(1):
132
- num_samples = int(prompt_duration * sr)
133
- if current_audio.duration_seconds * 1000 < prompt_duration * 1000:
134
- raise ValueError("The prompt_duration is longer than the current audio length.")
135
 
136
- start_time = current_audio.duration_seconds * 1000 - prompt_duration * 1000
137
- prompt_audio = current_audio[start_time:]
138
 
139
- prompt_bytes = prompt_audio.export(format="wav").read()
140
- prompt_waveform, _ = torchaudio.load(io.BytesIO(prompt_bytes))
141
- prompt_waveform = prompt_waveform.to(device)
142
 
143
- prompt_waveform = preprocess_audio(prompt_waveform)
144
 
145
- output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
146
- output = output.cpu()
147
 
148
- if len(output.size()) > 2:
149
- output = output.squeeze()
150
 
151
- filename_without_extension = f'continue_{i}'
152
- filename_with_extension = f'{filename_without_extension}.wav'
153
- correct_filename_extension = f'{filename_without_extension}.wav.wav'
154
 
155
- audio_write(filename_with_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
156
- generated_audio_segment = AudioSegment.from_wav(correct_filename_extension)
157
 
158
- current_audio = current_audio[:start_time] + generated_audio_segment
159
 
160
- file_paths_for_cleanup.append(correct_filename_extension)
161
 
162
- combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
163
- current_audio.export(combined_audio_filename, format="mp3")
164
 
165
- for file_path in file_paths_for_cleanup:
166
- os.remove(file_path)
167
 
168
- return combined_audio_filename
169
 
170
  # Define the expandable sections
171
  musicgen_micro_blurb = """
@@ -269,10 +275,11 @@ with gr.Blocks() as iface:
269
  hidden_trigger = gr.Textbox(value="generate", visible=False)
270
 
271
  # Fixed click handlers - use hidden input for generate_drum_sample
272
- generate_button.click(generate_drum_sample, inputs=[hidden_trigger], outputs=[drum_audio])
273
- continue_drum_sample_button.click(continue_drum_sample, inputs=[drum_audio], outputs=[drum_audio])
274
- generate_music_button.click(generate_music, inputs=[drum_audio, prompt_duration, musicgen_model, output_duration], outputs=[output_audio])
275
- continue_button.click(continue_music, inputs=[output_audio, prompt_duration, musicgen_model, output_duration], outputs=continue_output_audio)
 
276
 
277
  if __name__ == "__main__":
278
  iface.launch()
 
18
  waveform_np = waveform.cpu().squeeze().numpy()
19
  return torch.from_numpy(waveform_np).unsqueeze(0).to(device)
20
 
21
+ # Test with a wrapper function
22
  @spaces.GPU
23
+ def _generate_drum_sample_internal():
24
  model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
25
  model.set_generation_params(duration=10)
26
  wav = model.generate_unconditional(1).squeeze(0)
 
32
 
33
  return filename_with_extension
34
 
35
+ # Regular function wrapper (no @spaces.GPU on this one)
36
+ def generate_drum_sample():
37
+ return _generate_drum_sample_internal()
38
+
39
+
40
+
41
+ # @spaces.GPU
42
+ # def continue_drum_sample(existing_audio_path):
43
+ # if existing_audio_path is None:
44
+ # return None
45
 
46
+ # existing_audio, sr = torchaudio.load(existing_audio_path)
47
+ # existing_audio = existing_audio.to(device)
48
 
49
+ # prompt_duration = 2
50
+ # output_duration = 10
51
 
52
+ # num_samples = int(prompt_duration * sr)
53
+ # if existing_audio.shape[1] < num_samples:
54
+ # raise ValueError("The existing audio is too short for the specified prompt duration.")
55
 
56
+ # start_sample = existing_audio.shape[1] - num_samples
57
+ # prompt_waveform = existing_audio[..., start_sample:]
58
 
59
+ # model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
60
+ # model.set_generation_params(duration=output_duration)
61
 
62
+ # output = model.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
63
+ # output = output.to(device)
64
 
65
+ # if output.dim() == 3:
66
+ # output = output.squeeze(0)
67
 
68
+ # if output.dim() == 1:
69
+ # output = output.unsqueeze(0)
70
 
71
+ # combined_audio = torch.cat((existing_audio, output), dim=1)
72
+ # combined_audio = combined_audio.cpu()
73
 
74
+ # combined_file_path = f'./continued_jungle_{random.randint(1000, 9999)}.wav'
75
+ # torchaudio.save(combined_file_path, combined_audio, sr)
76
 
77
+ # return combined_file_path
78
 
79
+ # @spaces.GPU
80
+ # def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
81
+ # if wav_filename is None:
82
+ # return None
83
 
84
+ # song, sr = torchaudio.load(wav_filename)
85
+ # song = song.to(device)
86
+
87
+ # model_name = musicgen_model.split(" ")[0]
88
+ # model_continue = MusicGen.get_pretrained(model_name)
89
+
90
+ # model_continue.set_generation_params(
91
+ # use_sampling=True,
92
+ # top_k=250,
93
+ # top_p=0.0,
94
+ # temperature=1.0,
95
+ # duration=output_duration,
96
+ # cfg_coef=3
97
+ # )
98
+
99
+ # prompt_waveform = song[..., :int(prompt_duration * sr)]
100
+ # prompt_waveform = preprocess_audio(prompt_waveform)
101
 
102
+ # output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
103
+ # output = output.cpu()
104
 
105
+ # if len(output.size()) > 2:
106
+ # output = output.squeeze()
107
 
108
+ # filename_without_extension = f'continued_music'
109
+ # filename_with_extension = f'{filename_without_extension}.wav'
110
+ # audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
111
 
112
+ # return filename_with_extension
113
 
114
+ # @spaces.GPU
115
+ # def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
116
+ # if input_audio_path is None:
117
+ # return None
118
 
119
+ # song, sr = torchaudio.load(input_audio_path)
120
+ # song = song.to(device)
121
 
122
+ # model_continue = MusicGen.get_pretrained(musicgen_model.split(" ")[0])
123
+ # model_continue.set_generation_params(
124
+ # use_sampling=True,
125
+ # top_k=250,
126
+ # top_p=0.0,
127
+ # temperature=1.0,
128
+ # duration=output_duration,
129
+ # cfg_coef=3
130
+ # )
131
 
132
+ # original_audio = AudioSegment.from_mp3(input_audio_path)
133
+ # current_audio = original_audio
134
 
135
+ # file_paths_for_cleanup = []
136
 
137
+ # for i in range(1):
138
+ # num_samples = int(prompt_duration * sr)
139
+ # if current_audio.duration_seconds * 1000 < prompt_duration * 1000:
140
+ # raise ValueError("The prompt_duration is longer than the current audio length.")
141
 
142
+ # start_time = current_audio.duration_seconds * 1000 - prompt_duration * 1000
143
+ # prompt_audio = current_audio[start_time:]
144
 
145
+ # prompt_bytes = prompt_audio.export(format="wav").read()
146
+ # prompt_waveform, _ = torchaudio.load(io.BytesIO(prompt_bytes))
147
+ # prompt_waveform = prompt_waveform.to(device)
148
 
149
+ # prompt_waveform = preprocess_audio(prompt_waveform)
150
 
151
+ # output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
152
+ # output = output.cpu()
153
 
154
+ # if len(output.size()) > 2:
155
+ # output = output.squeeze()
156
 
157
+ # filename_without_extension = f'continue_{i}'
158
+ # filename_with_extension = f'{filename_without_extension}.wav'
159
+ # correct_filename_extension = f'{filename_without_extension}.wav.wav'
160
 
161
+ # audio_write(filename_with_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
162
+ # generated_audio_segment = AudioSegment.from_wav(correct_filename_extension)
163
 
164
+ # current_audio = current_audio[:start_time] + generated_audio_segment
165
 
166
+ # file_paths_for_cleanup.append(correct_filename_extension)
167
 
168
+ # combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
169
+ # current_audio.export(combined_audio_filename, format="mp3")
170
 
171
+ # for file_path in file_paths_for_cleanup:
172
+ # os.remove(file_path)
173
 
174
+ # return combined_audio_filename
175
 
176
  # Define the expandable sections
177
  musicgen_micro_blurb = """
 
275
  hidden_trigger = gr.Textbox(value="generate", visible=False)
276
 
277
  # Fixed click handlers - use hidden input for generate_drum_sample
278
+ # Normal click connection
279
+ generate_button.click(generate_drum_sample, outputs=[drum_audio])
280
+ # continue_drum_sample_button.click(continue_drum_sample, inputs=[drum_audio], outputs=[drum_audio])
281
+ # generate_music_button.click(generate_music, inputs=[drum_audio, prompt_duration, musicgen_model, output_duration], outputs=[output_audio])
282
+ # continue_button.click(continue_music, inputs=[output_audio, prompt_duration, musicgen_model, output_duration], outputs=continue_output_audio)
283
 
284
  if __name__ == "__main__":
285
  iface.launch()