| | import base64 |
| | import io |
| | from typing import Any |
| |
|
| | import numpy as np |
| | import soundfile as sf |
| | import torch |
| | import torchaudio |
| |
|
| |
|
| | |
# Compatibility shim: some torchaudio releases do not expose
# list_audio_backends. NOTE(review): presumably required by the speechbrain
# import below, which queries available backends — confirm against the
# installed speechbrain version. The stub advertises only "soundfile".
if not hasattr(torchaudio, "list_audio_backends"):
    torchaudio.list_audio_backends = lambda: ["soundfile"]
| |
|
| | from speechbrain.inference.separation import SepformerSeparation |
| |
|
| |
|
| | TARGET_SAMPLE_RATE = 8000 |
| |
|
| |
|
class EndpointHandler:
    """Inference-endpoint handler for SepFormer speech separation.

    Accepts a request carrying audio (raw bytes, base64 text, or a nested
    JSON payload), converts it to mono float32 at ``TARGET_SAMPLE_RATE``,
    runs the separation model, and returns one base64-encoded WAV per
    estimated speaker.
    """

    def __init__(self, path: str = ""):
        """Load the SepFormer model from *path* (default: current directory).

        The model is placed on GPU when one is available; SpeechBrain's
        ``separate_batch`` moves input tensors to the model's device itself.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SepformerSeparation.from_hparams(
            source=path or ".",
            savedir=path or ".",
            run_opts={"device": device},
        )
        # Resample transforms are cached per input sample rate so repeated
        # requests at the same rate do not rebuild the transform (and its
        # filter kernel) on every call.
        self._resamplers: dict[int, torchaudio.transforms.Resample] = {}

    def __call__(self, data: Any) -> dict:
        """Run separation on the request payload.

        Returns:
            dict with ``num_speakers`` and ``sources``; each source entry
            holds ``speaker`` index, base64 WAV audio, ``sample_rate``, and
            ``mime_type``.
        """
        audio_bytes = self._extract_audio_bytes(data)
        waveform, sample_rate = self._load_audio(audio_bytes)

        with torch.no_grad():
            # separate_batch expects (batch, time).
            est_sources = self.model.separate_batch(waveform.unsqueeze(0))

        # (batch, time, n_sources) -> (time, n_sources); a 1-D result is
        # treated as a single source.
        est_sources = est_sources.squeeze(0).detach().cpu()
        if est_sources.ndim == 1:
            est_sources = est_sources.unsqueeze(-1)

        outputs = []
        for idx in range(est_sources.shape[-1]):
            source = est_sources[:, idx].numpy()
            # Separated sources can exceed [-1, 1]; sf.write encodes WAV as
            # 16-bit PCM by default and would clip. Peak-normalize only when
            # needed so in-range signals are written untouched.
            peak = float(np.abs(source).max()) if source.size else 0.0
            if peak > 1.0:
                source = source / peak
            buffer = io.BytesIO()
            sf.write(buffer, source, TARGET_SAMPLE_RATE, format="WAV")
            outputs.append(
                {
                    "speaker": idx,
                    "audio_base64": base64.b64encode(buffer.getvalue()).decode("utf-8"),
                    "sample_rate": TARGET_SAMPLE_RATE,
                    "mime_type": "audio/wav",
                }
            )

        return {
            "num_speakers": len(outputs),
            "sources": outputs,
        }

    def _extract_audio_bytes(self, data: Any) -> bytes:
        """Pull raw audio bytes out of the request shapes we accept.

        Supported forms: raw bytes/bytearray; a dict whose ``inputs`` (or the
        dict itself) is bytes, a base64 string, or a nested dict with one of
        the keys ``audio``/``audio_base64``/``data`` holding bytes or a
        base64 string.

        Raises:
            ValueError: if no audio could be found in the request.
        """
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)

        if isinstance(data, dict):
            payload = data.get("inputs", data)

            if isinstance(payload, (bytes, bytearray)):
                return bytes(payload)

            if isinstance(payload, str):
                return self._decode_base64_audio(payload)

            if isinstance(payload, dict):
                for key in ("audio", "audio_base64", "data"):
                    value = payload.get(key)
                    # Accept raw bytes under the nested keys as well as
                    # base64 text, mirroring the top-level handling.
                    if isinstance(value, (bytes, bytearray)):
                        return bytes(value)
                    if isinstance(value, str):
                        return self._decode_base64_audio(value)

        raise ValueError("Unsupported request format. Send raw audio bytes or a JSON body with base64 audio.")

    def _decode_base64_audio(self, value: str) -> bytes:
        """Decode base64 audio, tolerating a ``data:...;base64,`` URL prefix."""
        if "," in value and value.startswith("data:"):
            value = value.split(",", 1)[1]
        return base64.b64decode(value)

    def _load_audio(self, audio_bytes: bytes) -> tuple[torch.Tensor, int]:
        """Decode bytes to a mono float32 waveform at ``TARGET_SAMPLE_RATE``.

        Returns:
            (waveform, sample_rate): 1-D tensor of samples and the fixed
            target sample rate.
        """
        waveform, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        waveform = torch.from_numpy(waveform.T)  # (channels, time)

        # Downmix multi-channel audio to mono by averaging channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sample_rate != TARGET_SAMPLE_RATE:
            resampler = self._resamplers.get(sample_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
                self._resamplers[sample_rate] = resampler
            waveform = resampler(waveform)

        return waveform.squeeze(0), TARGET_SAMPLE_RATE
| |
|