"""Inference-endpoint handler wrapping SpeechBrain SepFormer speech separation."""

import base64
import io
from typing import Any

import numpy as np
import soundfile as sf
import torch
import torchaudio

# SpeechBrain 1.0.x still expects this legacy torchaudio helper.
if not hasattr(torchaudio, "list_audio_backends"):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

from speechbrain.inference.separation import SepformerSeparation

# All audio is resampled to this rate before separation and written out at it.
TARGET_SAMPLE_RATE = 8000


class EndpointHandler:
    """Request handler: decodes audio, runs SepFormer, returns per-speaker WAVs."""

    def __init__(self, path: str = ""):
        """Load the SepFormer model from *path* (or the CWD), on GPU when available.

        Args:
            path: Model directory; used as both ``source`` and ``savedir``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SepformerSeparation.from_hparams(
            source=path or ".",
            savedir=path or ".",
            run_opts={"device": device},
        )

    def __call__(self, data: Any) -> dict:
        """Separate the request audio into per-speaker tracks.

        Args:
            data: Raw audio bytes, or a dict payload (see ``_extract_audio_bytes``).

        Returns:
            Dict with ``num_speakers`` and ``sources`` — one entry per estimated
            speaker, each carrying a base64-encoded mono WAV at 8 kHz.
        """
        audio_bytes = self._extract_audio_bytes(data)
        waveform, _ = self._load_audio(audio_bytes)

        with torch.no_grad():
            # separate_batch takes (batch, time); output is (batch, time, n_sources).
            est_sources = self.model.separate_batch(waveform.unsqueeze(0))

        est_sources = est_sources.squeeze(0).detach().cpu()
        if est_sources.ndim == 1:
            # Degenerate single-source output: restore the trailing source axis.
            est_sources = est_sources.unsqueeze(-1)

        outputs = []
        for idx in range(est_sources.shape[-1]):
            source = est_sources[:, idx].numpy()
            buffer = io.BytesIO()
            sf.write(buffer, source, TARGET_SAMPLE_RATE, format="WAV")
            outputs.append(
                {
                    "speaker": idx,
                    "audio_base64": base64.b64encode(buffer.getvalue()).decode("utf-8"),
                    "sample_rate": TARGET_SAMPLE_RATE,
                    "mime_type": "audio/wav",
                }
            )

        return {
            "num_speakers": len(outputs),
            "sources": outputs,
        }

    def _extract_audio_bytes(self, data: Any) -> bytes:
        """Pull raw audio bytes out of the request payload.

        Accepts: raw ``bytes``/``bytearray``; a dict whose ``inputs`` (or the
        dict itself) is bytes or a base64 string; or a nested dict with base64
        audio under one of the keys ``audio``, ``audio_base64``, ``data``.

        Raises:
            ValueError: If no recognizable audio payload is found.
        """
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)
        if isinstance(data, dict):
            payload = data.get("inputs", data)
            if isinstance(payload, (bytes, bytearray)):
                return bytes(payload)
            if isinstance(payload, str):
                return self._decode_base64_audio(payload)
            if isinstance(payload, dict):
                for key in ("audio", "audio_base64", "data"):
                    value = payload.get(key)
                    if isinstance(value, str):
                        return self._decode_base64_audio(value)
        raise ValueError(
            "Unsupported request format. "
            "Send raw audio bytes or a JSON body with base64 audio."
        )

    def _decode_base64_audio(self, value: str) -> bytes:
        """Decode a base64 string, stripping an optional ``data:`` URI prefix."""
        if "," in value and value.startswith("data:"):
            value = value.split(",", 1)[1]
        return base64.b64decode(value)

    def _load_audio(self, audio_bytes: bytes) -> tuple[torch.Tensor, int]:
        """Decode audio bytes into a mono float32 tensor at the target rate.

        Multi-channel input is downmixed to mono by averaging channels; any
        other sample rate is resampled to ``TARGET_SAMPLE_RATE``.

        Returns:
            ``(waveform, sample_rate)`` — a 1-D tensor of samples and the
            (fixed) target sample rate.
        """
        waveform, sample_rate = sf.read(
            io.BytesIO(audio_bytes), dtype="float32", always_2d=True
        )
        waveform = torch.from_numpy(waveform.T)  # -> (channels, time)
        if waveform.shape[0] > 1:
            # Downmix multi-channel audio to mono.
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)
        return waveform.squeeze(0), TARGET_SAMPLE_RATE