Carley1234 commited on
Commit
8f336f4
·
verified ·
1 Parent(s): 8727b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import asyncio
4
  import base64
5
  import datetime
6
  import torch
 
7
  import scipy.io.wavfile
8
  from fastapi import FastAPI, HTTPException, Form
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -95,8 +96,15 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
95
  sampling_rate = result["sampling_rate"]
96
  audio_data = result["audio"]
97
 
 
 
 
 
 
 
 
98
  wav_buf = io.BytesIO()
99
- scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=audio_data[0])
100
  wav_buf.seek(0)
101
 
102
  audio_base64 = base64.b64encode(wav_buf.read()).decode('utf-8')
 
4
  import base64
5
  import datetime
6
  import torch
7
+ import numpy as np
8
  import scipy.io.wavfile
9
  from fastapi import FastAPI, HTTPException, Form
10
  from fastapi.middleware.cors import CORSMiddleware
 
96
  sampling_rate = result["sampling_rate"]
97
  audio_data = result["audio"]
98
 
99
+ # Ensure audio_data is a numpy array and has correct type for scipy
100
+ if isinstance(audio_data, torch.Tensor):
101
+ audio_data = audio_data.cpu().numpy()
102
+
103
+ # Squeeze if necessary
104
+ audio_data = np.squeeze(audio_data)
105
+
106
  wav_buf = io.BytesIO()
107
+ scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=audio_data)
108
  wav_buf.seek(0)
109
 
110
  audio_base64 = base64.b64encode(wav_buf.read()).decode('utf-8')