|
|
from typing import Dict, List, Any |
|
|
from kokoro import KPipeline |
|
|
from IPython.display import display, Audio |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import io |
|
|
import os |
|
|
|
|
|
class EndpointHandler():
    """Inference endpoint handler for Kokoro text-to-speech.

    Loads a ``KPipeline`` once at construction time and turns request
    payloads of the form ``{"text": ..., "voice": ...}`` into a single
    in-memory WAV byte blob.
    """

    def __init__(self):
        # lang_code 'a' selects the American-English Kokoro pipeline
        # (matches the original handler's configuration).
        self.pipeline = KPipeline(lang_code='a')

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for one request.

        Args:
            inputs: Mapping with two required keys:
                ``"text"``  — the text to synthesize.
                ``"voice"`` — the Kokoro voice identifier.

        Returns:
            ``{"audio": <WAV-encoded bytes, 24 kHz sample rate>}``

        Raises:
            KeyError: if ``"text"`` or ``"voice"`` is missing from ``inputs``.
            ValueError: if the pipeline produced no audio segments.
        """
        text = inputs["text"]
        voice = inputs["voice"]

        # Collect every generated segment directly in memory.
        # NOTE(fix): the previous implementation also wrote one "<i>.wav"
        # file per segment to the working directory and deleted them all
        # afterwards — the files were never read, and concurrent requests
        # would clobber/delete each other's files. That disk round-trip
        # is removed entirely.
        audio_segments = [
            torch.as_tensor(audio)
            for _graphemes, _phonemes, audio in self.pipeline(text, voice)
        ]

        # torch.cat([]) raises an opaque RuntimeError; fail with a clear
        # message instead when synthesis yielded nothing.
        if not audio_segments:
            raise ValueError("TTS pipeline produced no audio segments")

        full_audio = torch.cat(audio_segments)

        # Encode the concatenated waveform as WAV (24 kHz) in memory.
        buffer = io.BytesIO()
        sf.write(buffer, full_audio.numpy(), 24000, format='WAV')
        buffer.seek(0)

        return {
            "audio": buffer.read()
        }