mms-tts-eng / handler.py
oyemade's picture
Upload 2 files
1689b5f verified
raw
history blame
1.06 kB
# handler.py
from typing import Any, Dict, Optional

from transformers import pipeline
class EndpointHandler:
    """Custom inference-endpoint handler wrapping an MMS-TTS pipeline.

    The pipeline is built once in ``__init__``; each request is served by
    calling the instance with the request payload.
    """

    def __init__(self, model_path: str = "", device: Optional[int] = None):
        """Load the MMS-TTS pipeline once at startup.

        transformers>=4.33.0 is required for MMS-TTS support.

        Args:
            model_path: Model location passed as ``model=`` to
                ``transformers.pipeline``.
            device: Device index forwarded to ``pipeline`` (``0`` = first
                CUDA GPU, ``-1`` = CPU). When ``None`` (the default), the
                first GPU is used if CUDA is available, otherwise the CPU.
        """
        if device is None:
            device = self._default_device()
        self.tts = pipeline(
            task="text-to-speech",
            model=model_path,
            device=device,
        )

    @staticmethod
    def _default_device() -> int:
        """Return 0 (first CUDA GPU) if CUDA is available, else -1 (CPU)."""
        # Previously device=0 was hard-coded, which fails on CPU-only hosts.
        # torch is imported lazily so the module import itself stays light.
        import torch

        return 0 if torch.cuda.is_available() else -1

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for a single request.

        Args:
            data: Request payload of the form
                ``{"inputs": "<text to synthesize>"}``.

        Returns:
            ``{"array": <nested list of floats>, "sampling_rate": <int>}``,
            where ``array`` is the pipeline's ``audio`` array transposed and
            converted to plain Python lists.

        Raises:
            ValueError: If ``inputs`` is missing, not a string, or blank.
        """
        text = data.get("inputs", "")
        # Reject bad input up front with a clear message instead of letting
        # the tokenizer/model fail on it somewhere deeper.
        if not isinstance(text, str) or not text.strip():
            raise ValueError('Expected a non-empty string under "inputs".')

        # The text-to-speech pipeline result carries "audio" and "sampling_rate".
        result = self.tts(text)
        audio = result["audio"]
        return {
            # NOTE(review): audio is presumably shaped (channels, samples), so
            # .T yields one single-element row per sample. The shape is kept
            # unchanged for compatibility with existing clients; confirm.
            "array": audio.T.tolist(),
            "sampling_rate": result["sampling_rate"],
        }