ghostai1 committed on
Commit
798e897
·
verified ·
1 Parent(s): e433b2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -34
app.py CHANGED
@@ -10,7 +10,6 @@ import gradio as gr
10
  from pydub import AudioSegment
11
  from audiocraft.models import MusicGen
12
  from torch.cuda.amp import autocast
13
- from pydub.effects import reverb
14
 
15
  # Set PYTORCH_CUDA_ALLOC_CONF to manage memory fragmentation
16
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
@@ -82,25 +81,28 @@ def apply_chorus(segment):
82
  delayed = delayed.set_frame_rate(segment.frame_rate)
83
  return segment.overlay(delayed, position=20)
84
 
 
 
 
 
 
 
 
 
 
85
  def apply_eq(segment):
86
  # Adjusted EQ for clarity in midrange
87
  segment = segment.low_pass_filter(8000)
88
  segment = segment.high_pass_filter(80)
89
- # Boost midrange frequencies (500 Hz to 2 kHz) for clarity
90
  segment = segment.equalizer(frequency=1000, gain=2, q=1.0)
91
  return segment
92
 
93
- def apply_reverb(segment):
94
- # Add subtle reverb for depth
95
- return reverb(segment, reverb_time=1500, wet_level=0.2)
96
-
97
  def apply_limiter(segment, max_db=-3.0):
98
  if segment.dBFS > max_db:
99
  segment = segment - (segment.dBFS - max_db)
100
  return segment
101
 
102
  def apply_final_gain(segment, target_db=-12.0):
103
- # Adjust final gain to a safe loudness level
104
  gain_adjustment = target_db - segment.dBFS
105
  return segment + gain_adjustment
106
 
@@ -113,25 +115,21 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
113
  start_time = time.time()
114
 
115
  total_duration = min(max(total_duration, 10), 90)
116
- chunk_duration = 15
117
- num_chunks = max(2, (total_duration + chunk_duration - 1) // chunk_duration)
118
- chunk_duration = total_duration / num_chunks
119
-
120
- overlap_duration = min(1.0, crossfade_duration / 1000.0)
121
- generation_duration = chunk_duration + overlap_duration
122
 
123
  audio_chunks = []
124
  sample_rate = musicgen_model.sample_rate
125
 
126
  for i in range(num_chunks):
127
- chunk_prompt = instrumental_prompt # Use the same prompt for all chunks
128
  print(f"Generating chunk {i+1}/{num_chunks} on GPU (prompt: {chunk_prompt})...")
129
  musicgen_model.set_generation_params(
130
- duration=generation_duration,
131
  use_sampling=True,
132
- top_k=top_k,
133
- top_p=top_p,
134
- temperature=temperature,
135
  cfg_coef=cfg_scale
136
  )
137
 
@@ -155,13 +153,16 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
155
  if audio_chunk.shape[0] != 2:
156
  raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
157
 
158
- temp_wav_path = f"temp_chunk_{i}.wav"
159
- chunk_path = f"chunk_{i}.mp3"
160
- torchaudio.save(temp_wav_path, audio_chunk, sample_rate, bits_per_sample=24)
161
- segment = AudioSegment.from_wav(temp_wav_path)
162
- segment.export(chunk_path, format="mp3", bitrate="320k")
163
- os.remove(temp_wav_path)
164
- audio_chunks.append(chunk_path)
 
 
 
165
 
166
  torch.cuda.empty_cache()
167
  gc.collect()
@@ -169,9 +170,9 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
169
  print_resource_usage(f"After Chunk {i+1} Generation")
170
 
171
  print("Combining audio chunks...")
172
- final_segment = AudioSegment.from_mp3(audio_chunks[0])
173
  for i in range(1, len(audio_chunks)):
174
- next_segment = AudioSegment.from_mp3(audio_chunks[i])
175
  next_segment = next_segment + 1
176
  final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)
177
 
@@ -194,9 +195,6 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
194
  )
195
  print(f"Saved final audio to {mp3_path}")
196
 
197
- for chunk_path in audio_chunks:
198
- os.remove(chunk_path)
199
-
200
  print_resource_usage("After Final Generation")
201
  print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")
202
 
@@ -208,7 +206,7 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
208
  gc.collect()
209
 
210
  def clear_inputs():
211
- return "", 3.0, 300, 0.95, 1.0, 45, 750
212
 
213
  # 7) CUSTOM CSS (Unchanged)
214
  css = """
@@ -379,7 +377,7 @@ with gr.Blocks(css=css) as demo:
379
  label="Top-K Sampling",
380
  minimum=10,
381
  maximum=500,
382
- value=300,
383
  step=10,
384
  info="Limits sampling to the top k most likely tokens. Higher values increase diversity."
385
  )
@@ -387,7 +385,7 @@ with gr.Blocks(css=css) as demo:
387
  label="Top-P Sampling (Nucleus Sampling)",
388
  minimum=0.0,
389
  maximum=1.0,
390
- value=0.95,
391
  step=0.1,
392
  info="Keeps tokens with cumulative probability above p. Higher values increase diversity."
393
  )
@@ -395,7 +393,7 @@ with gr.Blocks(css=css) as demo:
395
  label="Temperature",
396
  minimum=0.1,
397
  maximum=2.0,
398
- value=1.0,
399
  step=0.1,
400
  info="Controls randomness. Higher values make output more diverse but less predictable."
401
  )
 
10
  from pydub import AudioSegment
11
  from audiocraft.models import MusicGen
12
  from torch.cuda.amp import autocast
 
13
 
14
  # Set PYTORCH_CUDA_ALLOC_CONF to manage memory fragmentation
15
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
 
81
  delayed = delayed.set_frame_rate(segment.frame_rate)
82
  return segment.overlay(delayed, position=20)
83
 
84
+ def apply_reverb(segment):
85
+ # Simulate reverb by overlaying multiple delayed copies with decreasing amplitude
86
+ reverb_segment = segment
87
+ for delay_ms, gain_db in [(50, -10), (100, -15), (150, -20)]:
88
+ delayed = segment - gain_db
89
+ delayed = delayed.set_frame_rate(segment.frame_rate)
90
+ reverb_segment = reverb_segment.overlay(delayed, position=delay_ms)
91
+ return reverb_segment
92
+
93
  def apply_eq(segment):
94
  # Adjusted EQ for clarity in midrange
95
  segment = segment.low_pass_filter(8000)
96
  segment = segment.high_pass_filter(80)
 
97
  segment = segment.equalizer(frequency=1000, gain=2, q=1.0)
98
  return segment
99
 
 
 
 
 
100
  def apply_limiter(segment, max_db=-3.0):
101
  if segment.dBFS > max_db:
102
  segment = segment - (segment.dBFS - max_db)
103
  return segment
104
 
105
  def apply_final_gain(segment, target_db=-12.0):
 
106
  gain_adjustment = target_db - segment.dBFS
107
  return segment + gain_adjustment
108
 
 
115
  start_time = time.time()
116
 
117
  total_duration = min(max(total_duration, 10), 90)
118
+ chunk_duration = total_duration # Single chunk to minimize overhead
119
+ num_chunks = 1 # Single chunk generation
 
 
 
 
120
 
121
  audio_chunks = []
122
  sample_rate = musicgen_model.sample_rate
123
 
124
  for i in range(num_chunks):
125
+ chunk_prompt = instrumental_prompt
126
  print(f"Generating chunk {i+1}/{num_chunks} on GPU (prompt: {chunk_prompt})...")
127
  musicgen_model.set_generation_params(
128
+ duration=chunk_duration,
129
  use_sampling=True,
130
+ top_k=250, # Reduced for faster generation
131
+ top_p=0.9, # Adjusted for balance
132
+ temperature=0.9, # Slightly reduced for consistency
133
  cfg_coef=cfg_scale
134
  )
135
 
 
153
  if audio_chunk.shape[0] != 2:
154
  raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
155
 
156
+ # Process in memory using pydub without intermediate file I/O
157
+ audio_array = audio_chunk.numpy()
158
+ audio_array = (audio_array * 32767).astype(np.int16) # Convert to 16-bit PCM
159
+ segment = AudioSegment(
160
+ audio_array.tobytes(),
161
+ frame_rate=sample_rate,
162
+ sample_width=2, # 16-bit
163
+ channels=2
164
+ )
165
+ audio_chunks.append(segment)
166
 
167
  torch.cuda.empty_cache()
168
  gc.collect()
 
170
  print_resource_usage(f"After Chunk {i+1} Generation")
171
 
172
  print("Combining audio chunks...")
173
+ final_segment = audio_chunks[0]
174
  for i in range(1, len(audio_chunks)):
175
+ next_segment = audio_chunks[i]
176
  next_segment = next_segment + 1
177
  final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)
178
 
 
195
  )
196
  print(f"Saved final audio to {mp3_path}")
197
 
 
 
 
198
  print_resource_usage("After Final Generation")
199
  print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")
200
 
 
206
  gc.collect()
207
 
208
  def clear_inputs():
209
+ return "", 3.0, 250, 0.9, 0.9, 45, 750
210
 
211
  # 7) CUSTOM CSS (Unchanged)
212
  css = """
 
377
  label="Top-K Sampling",
378
  minimum=10,
379
  maximum=500,
380
+ value=250,
381
  step=10,
382
  info="Limits sampling to the top k most likely tokens. Higher values increase diversity."
383
  )
 
385
  label="Top-P Sampling (Nucleus Sampling)",
386
  minimum=0.0,
387
  maximum=1.0,
388
+ value=0.9,
389
  step=0.1,
390
  info="Keeps tokens with cumulative probability above p. Higher values increase diversity."
391
  )
 
393
  label="Temperature",
394
  minimum=0.1,
395
  maximum=2.0,
396
+ value=0.9,
397
  step=0.1,
398
  info="Controls randomness. Higher values make output more diverse but less predictable."
399
  )