|
|
import base64
import io
import re

import soundfile as sf
import torch

from fish_speech.inference import infer
from fish_speech.models.fish_speech import FishSpeech
|
|
|
|
|
|
|
|
|
# Checkpoint identifier passed to ``FishSpeech.from_pretrained``.
_MODEL_ID = "fishaudio/fish-speech-1.5"

# Loaded once, at import time; every ``predict`` call reuses this instance.
model = FishSpeech.from_pretrained(_MODEL_ID)
|
|
|
|
|
|
def predict(inputs: dict, *, sample_rate: int = 24000) -> dict:
    """Run ``infer`` on the request text and return the audio as base64 WAV.

    A ``[singing]`` tag anywhere in the text, matched case-insensitively,
    selects ``mode="singing"``. Every occurrence of the tag is removed
    before the text is handed to ``infer``. Otherwise ``mode="speech"`` is
    used and the text is passed through unchanged.

    Args:
        inputs: Request payload. The text is read from its ``"inputs"`` key.
            ``'Hello world'`` is used when that key is missing or ``None``.
        sample_rate: Sample rate written into the WAV header. Defaults to
            24000, the value this endpoint has always used.
            NOTE(review): presumably this must equal the model's output
            rate. Confirm against the model config.

    Returns:
        ``{"audio": <base64-encoded WAV file as an ASCII str>}``.
    """
    text = inputs.get("inputs")
    if text is None:
        # A present-but-null key previously reached ``.lower()`` and crashed.
        text = "Hello world"

    # Detect and strip the tag in one case-insensitive pass. The old code
    # detected case-insensitively but removed case-sensitively, so tags
    # like "[Singing]" were left in the text sent to the model.
    stripped, tag_count = re.subn(r"\[singing\]", "", text, flags=re.IGNORECASE)
    if tag_count:
        mode = "singing"
        text = stripped.strip()  # drop whitespace left where the tag was
    else:
        mode = "speech"

    audio = infer(model, text, mode=mode)

    # detach(): .numpy() raises on tensors that require grad.
    # NOTE(review): soundfile expects (frames,) or (frames, channels); the
    # shape returned by ``infer`` is not visible here. Confirm it matches.
    samples = audio.detach().cpu().numpy()

    with io.BytesIO() as buffer:
        sf.write(buffer, samples, sample_rate, format="WAV")
        wav_bytes = buffer.getvalue()

    return {"audio": base64.b64encode(wav_bytes).decode("ascii")}
|
|
|
|
|
|
def query(payload):
    """Handle a request by delegating it to ``predict``.

    Args:
        payload: Request payload, forwarded unchanged as ``predict``'s
            first argument.

    Returns:
        Exactly what ``predict`` returns for *payload*.
    """
    response = predict(payload)
    return response
|
|
|
|