| | import torchaudio as ta |
| | from chatterbox.tts import ChatterboxTTS |
| | from typing import Dict, Any, List |
| | import soundfile as sf |
| | import io |
| | import base64 |
| |
|
| |
|
class EndpointHandler:
    """Inference handler that synthesizes speech with Chatterbox TTS.

    Every request is voiced using a fixed reference audio prompt. The
    generated waveform is returned as a base64-encoded WAV file.
    """

    # Default reference audio used to condition the generated voice. A relative
    # path is resolved against the process working directory.
    # NOTE(review): confirm this file is present in the deployed working dir.
    DEFAULT_AUDIO_PROMPT_PATH: str = "arjun_das_output_audio.mp3"
    # Generation settings used when a request does not supply them.
    DEFAULT_EXAGGERATION: float = 0.3
    DEFAULT_CFG_WEIGHT: float = 0.5

    # Instance-level override is set in __init__; the class attribute keeps
    # instances that bypass __init__ working with the default prompt.
    audio_prompt_path: str = DEFAULT_AUDIO_PROMPT_PATH

    def __init__(
        self,
        path: str = "",
        device: str = "cuda",
        audio_prompt_path: str = DEFAULT_AUDIO_PROMPT_PATH,
    ):
        """Load the Chatterbox model.

        Args:
            path: Directory passed in by the serving runtime. Unused here:
                the model is loaded via ``ChatterboxTTS.from_pretrained``
                without it.
            device: Torch device to load the model onto (default ``"cuda"``).
            audio_prompt_path: Reference audio file used to condition the
                voice of every request.

        Raises:
            RuntimeError: If the model fails to load; the original exception
                is chained as ``__cause__``.
        """
        self.audio_prompt_path = audio_prompt_path
        try:
            self.model = ChatterboxTTS.from_pretrained(device=device)
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Synthesize speech for a single request.

        Args:
            data: Request payload. ``data["inputs"]`` is either the text to
                speak as a plain string, or an object with a ``text`` key and
                optional ``exaggeration`` (default 0.3) and ``cfg_weight``
                (default 0.5) keys.

        Returns:
            ``[{"audio_base64": <base64-encoded WAV>}]`` on success, or
            ``[{"error": <message>}]`` if parsing, generation or encoding fails.
        """
        try:
            text, exaggeration, cfg_weight = self._parse_inputs(data)
            wav = self.model.generate(
                text,
                audio_prompt_path=self.audio_prompt_path,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )
            return [{"audio_base64": self._encode_wav(wav)}]
        except Exception as e:
            # Request boundary: report the failure in the response body
            # instead of propagating the exception to the caller.
            print(f"[ERROR] Inference failed: {e}")
            return [{"error": str(e)}]

    @classmethod
    def _parse_inputs(cls, data: Dict[str, Any]) -> tuple:
        """Extract and validate ``(text, exaggeration, cfg_weight)`` from a payload.

        Raises:
            TypeError: If ``inputs`` is neither a string nor an object.
            ValueError: If ``text`` is missing/blank or a setting is not numeric.
        """
        inputs = data.get("inputs", {})
        if isinstance(inputs, str):
            # A plain-string payload carries only the text; use default settings.
            inputs = {"text": inputs}
        elif not isinstance(inputs, dict):
            raise TypeError("'inputs' must be a string or an object")

        text = inputs.get("text")
        if not isinstance(text, str) or not text.strip():
            raise ValueError("'text' must be a non-empty string")

        exaggeration = inputs.get("exaggeration")
        cfg_weight = inputs.get("cfg_weight")
        # float() accepts numeric strings from loosely typed JSON clients and
        # rejects non-numeric values before they reach the model.
        exaggeration = (
            cls.DEFAULT_EXAGGERATION if exaggeration is None else float(exaggeration)
        )
        cfg_weight = cls.DEFAULT_CFG_WEIGHT if cfg_weight is None else float(cfg_weight)
        return text, exaggeration, cfg_weight

    def _encode_wav(self, wav) -> str:
        """Serialize a generated waveform to a base64-encoded WAV string."""
        buffer = io.BytesIO()
        # soundfile expects samples along the first axis, hence the transpose.
        # NOTE(review): assumes ``wav`` is a (channels, samples) torch tensor as
        # returned by ChatterboxTTS.generate — confirm.
        sf.write(buffer, wav.cpu().numpy().T, self.model.sr, format="WAV")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
| |
|