Carley1234 committed on
Commit
a0240fe
·
verified ·
1 Parent(s): 9fae321

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -89,7 +89,16 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
89
 
90
  def run_inference():
91
  with torch.no_grad():
92
- return audio_pipe(prompt, forward_params={"max_new_tokens": max_tokens})
 
 
 
 
 
 
 
 
 
93
 
94
  result = await asyncio.to_thread(run_inference)
95
 
@@ -100,15 +109,20 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
100
  if isinstance(audio_data, torch.Tensor):
101
  audio_data = audio_data.cpu().numpy()
102
 
103
- # Squeeze if necessary
104
- audio_data = np.squeeze(audio_data)
105
 
106
- # Normalize audio to -1.0 to 1.0 range if it isn't already
 
 
 
 
 
107
  max_val = np.abs(audio_data).max()
108
  if max_val > 0:
109
- audio_data = audio_data / max_val
110
 
111
- # Convert to 16-bit PCM (standard WAV format) for better quality/compatibility
112
  audio_data = (audio_data * 32767).astype(np.int16)
113
 
114
  wav_buf = io.BytesIO()
 
89
 
90
  def run_inference():
91
  with torch.no_grad():
92
+ # Enabling sampling for AudioGen-small as well
93
+ return audio_pipe(
94
+ prompt,
95
+ forward_params={
96
+ "max_new_tokens": max_tokens,
97
+ "do_sample": True,
98
+ "temperature": 1.0,
99
+ "top_k": 250
100
+ }
101
+ )
102
 
103
  result = await asyncio.to_thread(run_inference)
104
 
 
109
  if isinstance(audio_data, torch.Tensor):
110
  audio_data = audio_data.cpu().numpy()
111
 
112
+ # Clean data and handle dimensions
113
+ audio_data = np.nan_to_num(audio_data)
114
 
115
+ if audio_data.ndim > 1:
116
+ audio_data = audio_data[0]
117
+ if audio_data.ndim > 1:
118
+ audio_data = np.mean(audio_data, axis=0)
119
+
120
+ # Normalize audio to -1.0 to 1.0 range
121
  max_val = np.abs(audio_data).max()
122
  if max_val > 0:
123
+ audio_data = audio_data / (max_val + 1e-6)
124
 
125
+ # Convert to 16-bit PCM (standard WAV format)
126
  audio_data = (audio_data * 32767).astype(np.int16)
127
 
128
  wav_buf = io.BytesIO()