| | import torch |
| | from fish_speech.models.fish_speech import FishSpeech |
| | from fish_speech.inference import infer |
| | import io |
| | import base64 |
| | import soundfile as sf |
| |
|
| | |
# Checkpoint identifier handed to FishSpeech.from_pretrained.
_CHECKPOINT_ID = 'fishaudio/fish-speech-1.5'

# Built once at import time; predict() reads this module-level instance on
# every call instead of constructing a new model per request.
model = FishSpeech.from_pretrained(_CHECKPOINT_ID)
| |
|
def predict(inputs: dict) -> dict:
    """Synthesize audio for the requested text and return it as base64 WAV.

    The text may contain a ``[singing]`` tag (any letter case). When present,
    every occurrence of the tag is removed and synthesis runs with
    ``mode="singing"``; otherwise the text is used as-is with
    ``mode="speech"``.

    Args:
        inputs: Request payload. The text is read from the ``'inputs'`` key;
            ``'Hello world'`` is used when the key is absent.

    Returns:
        A dict ``{"audio": <str>}`` whose value is the base64-encoded bytes
        of the WAV written by ``soundfile``.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import re

    text = inputs.get('inputs', 'Hello world')

    # Detect AND strip the tag with the same ASCII case-insensitive rule.
    # Detection used to lowercase the text while removal matched only the
    # exact lowercase spelling, so e.g. "[SINGING] la la" selected singing
    # mode but left the literal tag in the text handed to the model.
    text, tag_count = re.subn(
        r"\[singing\]", "", text, flags=re.IGNORECASE | re.ASCII
    )
    if tag_count:
        mode = "singing"
        # Drop the whitespace that surrounded the removed tag.
        text = text.strip()
    else:
        mode = "speech"

    audio = infer(model, text, mode=mode)

    # Serialize to an in-memory WAV so nothing touches the filesystem.
    # NOTE(review): 24000 is passed as the sample rate; confirm it matches
    # the rate of the audio returned by infer().
    buffer = io.BytesIO()
    sf.write(buffer, audio.cpu().numpy(), 24000, format='WAV')
    audio_b64 = base64.b64encode(buffer.getvalue()).decode()

    return {"audio": audio_b64}
| |
|
def query(payload):
    """Forward *payload* to ``predict`` and return its result unchanged."""
    response = predict(payload)
    return response