import base64
import io
from typing import Any
import numpy as np
import soundfile as sf
import torch
import torchaudio
# SpeechBrain 1.0.x still expects this legacy torchaudio helper.
# If the installed torchaudio no longer provides list_audio_backends,
# install a stub so the speechbrain import below does not fail.
# NOTE(review): advertising only "soundfile" matches _load_audio, which
# decodes via the soundfile library — confirm if other backends matter.
if not hasattr(torchaudio, "list_audio_backends"):
torchaudio.list_audio_backends = lambda: ["soundfile"]
from speechbrain.inference.separation import SepformerSeparation
# Sample rate (Hz) the separation model works at: _load_audio resamples
# every input to this rate and output WAVs are written at it.
TARGET_SAMPLE_RATE = 8000
class EndpointHandler:
    """Inference endpoint for a SpeechBrain SepFormer speech-separation model.

    Accepts raw audio bytes or base64-encoded audio (optionally wrapped in a
    JSON body), resamples the input to ``TARGET_SAMPLE_RATE``, runs source
    separation, and returns every estimated source as a base64-encoded WAV.
    """

    def __init__(self, path: str = ""):
        """Load the model from *path* (default: current directory), on GPU if available."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SepformerSeparation.from_hparams(
            source=path or ".",
            savedir=path or ".",
            run_opts={"device": device},
        )

    def __call__(self, data: Any) -> dict:
        """Separate the speakers in the request audio.

        Parameters
        ----------
        data:
            Raw audio bytes, a base64 string, or a dict whose ``inputs`` key
            holds either of those or a nested dict with an
            ``audio``/``audio_base64``/``data`` base64 field.

        Returns
        -------
        dict
            ``{"num_speakers": N, "sources": [...]}`` where each source entry
            carries base64 WAV audio at ``TARGET_SAMPLE_RATE``.

        Raises
        ------
        ValueError
            If no audio payload can be extracted from *data*.
        """
        audio_bytes = self._extract_audio_bytes(data)
        waveform, sample_rate = self._load_audio(audio_bytes)
        with torch.no_grad():
            # separate_batch takes (batch, time); add the batch axis here.
            est_sources = self.model.separate_batch(waveform.unsqueeze(0))
        # Drop the batch axis; expected shape is now (time, n_sources).
        est_sources = est_sources.squeeze(0).detach().cpu()
        if est_sources.ndim == 1:
            # Degenerate single-source output: restore an explicit source axis.
            est_sources = est_sources.unsqueeze(-1)
        outputs = []
        for idx in range(est_sources.shape[-1]):
            source = est_sources[:, idx].numpy()
            buffer = io.BytesIO()
            sf.write(buffer, source, TARGET_SAMPLE_RATE, format="WAV")
            outputs.append(
                {
                    "speaker": idx,
                    "audio_base64": base64.b64encode(buffer.getvalue()).decode("utf-8"),
                    "sample_rate": TARGET_SAMPLE_RATE,
                    "mime_type": "audio/wav",
                }
            )
        return {
            "num_speakers": len(outputs),
            "sources": outputs,
        }

    def _extract_audio_bytes(self, data: Any) -> bytes:
        """Pull raw audio bytes out of any of the accepted request shapes."""
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)
        if isinstance(data, str):
            # Bare base64 body (e.g. a text/plain POST). Previously only a
            # base64 string nested under "inputs" was accepted.
            return self._decode_base64_audio(data)
        if isinstance(data, dict):
            payload = data.get("inputs", data)
            if isinstance(payload, (bytes, bytearray)):
                return bytes(payload)
            if isinstance(payload, str):
                return self._decode_base64_audio(payload)
            if isinstance(payload, dict):
                # Probe the common field names in priority order.
                for key in ("audio", "audio_base64", "data"):
                    value = payload.get(key)
                    if isinstance(value, str):
                        return self._decode_base64_audio(value)
        raise ValueError("Unsupported request format. Send raw audio bytes or a JSON body with base64 audio.")

    def _decode_base64_audio(self, value: str) -> bytes:
        """Decode a base64 string, tolerating a ``data:`` URI prefix."""
        if "," in value and value.startswith("data:"):
            # Strip the "data:<mime>;base64," header, keep the payload.
            value = value.split(",", 1)[1]
        return base64.b64decode(value)

    def _load_audio(self, audio_bytes: bytes) -> tuple[torch.Tensor, int]:
        """Decode *audio_bytes* to a mono 1-D float tensor at TARGET_SAMPLE_RATE.

        Multi-channel input is downmixed by averaging channels; any sample
        rate other than the target is resampled.
        """
        waveform, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        # soundfile returns (frames, channels); transpose to (channels, frames).
        waveform = torch.from_numpy(waveform.T)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)
        return waveform.squeeze(0), TARGET_SAMPLE_RATE