Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import asyncio
|
|
| 4 |
import base64
|
| 5 |
import datetime
|
| 6 |
import torch
|
|
|
|
| 7 |
import scipy.io.wavfile
|
| 8 |
from fastapi import FastAPI, HTTPException, Form
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -95,8 +96,15 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
|
|
| 95 |
sampling_rate = result["sampling_rate"]
|
| 96 |
audio_data = result["audio"]
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
wav_buf = io.BytesIO()
|
| 99 |
-
scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=audio_data
|
| 100 |
wav_buf.seek(0)
|
| 101 |
|
| 102 |
audio_base64 = base64.b64encode(wav_buf.read()).decode('utf-8')
|
|
|
|
| 4 |
import base64
|
| 5 |
import datetime
|
| 6 |
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
import scipy.io.wavfile
|
| 9 |
from fastapi import FastAPI, HTTPException, Form
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 96 |
sampling_rate = result["sampling_rate"]
|
| 97 |
audio_data = result["audio"]
|
| 98 |
|
| 99 |
+
# Ensure audio_data is a numpy array and has correct type for scipy
|
| 100 |
+
if isinstance(audio_data, torch.Tensor):
|
| 101 |
+
audio_data = audio_data.cpu().numpy()
|
| 102 |
+
|
| 103 |
+
# Squeeze if necessary
|
| 104 |
+
audio_data = np.squeeze(audio_data)
|
| 105 |
+
|
| 106 |
wav_buf = io.BytesIO()
|
| 107 |
+
scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=audio_data)
|
| 108 |
wav_buf.seek(0)
|
| 109 |
|
| 110 |
audio_base64 = base64.b64encode(wav_buf.read()).decode('utf-8')
|