Codex Bot commited on
Commit
abffd77
·
1 Parent(s): 4298bdc

Add custom inference endpoint handler

Browse files
Files changed (2) hide show
  1. handler.py +97 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import torch
8
+ import torchaudio
9
+
10
+
11
+ # SpeechBrain 1.0.x still expects this legacy torchaudio helper.
12
+ if not hasattr(torchaudio, "list_audio_backends"):
13
+ torchaudio.list_audio_backends = lambda: ["soundfile"]
14
+
15
+ from speechbrain.inference.separation import SepformerSeparation
16
+
17
+
18
+ TARGET_SAMPLE_RATE = 8000
19
+
20
+
21
+ class EndpointHandler:
22
+ def __init__(self, path: str = ""):
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ self.model = SepformerSeparation.from_hparams(
25
+ source=path or ".",
26
+ savedir=path or ".",
27
+ run_opts={"device": device},
28
+ )
29
+
30
+ def __call__(self, data: Any) -> dict:
31
+ audio_bytes = self._extract_audio_bytes(data)
32
+ waveform, sample_rate = self._load_audio(audio_bytes)
33
+
34
+ with torch.no_grad():
35
+ est_sources = self.model.separate_batch(waveform.unsqueeze(0))
36
+
37
+ est_sources = est_sources.squeeze(0).detach().cpu()
38
+ if est_sources.ndim == 1:
39
+ est_sources = est_sources.unsqueeze(-1)
40
+
41
+ outputs = []
42
+ for idx in range(est_sources.shape[-1]):
43
+ source = est_sources[:, idx].numpy()
44
+ buffer = io.BytesIO()
45
+ sf.write(buffer, source, TARGET_SAMPLE_RATE, format="WAV")
46
+ outputs.append(
47
+ {
48
+ "speaker": idx,
49
+ "audio_base64": base64.b64encode(buffer.getvalue()).decode("utf-8"),
50
+ "sample_rate": TARGET_SAMPLE_RATE,
51
+ "mime_type": "audio/wav",
52
+ }
53
+ )
54
+
55
+ return {
56
+ "num_speakers": len(outputs),
57
+ "sources": outputs,
58
+ }
59
+
60
+ def _extract_audio_bytes(self, data: Any) -> bytes:
61
+ if isinstance(data, (bytes, bytearray)):
62
+ return bytes(data)
63
+
64
+ if isinstance(data, dict):
65
+ payload = data.get("inputs", data)
66
+
67
+ if isinstance(payload, (bytes, bytearray)):
68
+ return bytes(payload)
69
+
70
+ if isinstance(payload, str):
71
+ return self._decode_base64_audio(payload)
72
+
73
+ if isinstance(payload, dict):
74
+ for key in ("audio", "audio_base64", "data"):
75
+ value = payload.get(key)
76
+ if isinstance(value, str):
77
+ return self._decode_base64_audio(value)
78
+
79
+ raise ValueError("Unsupported request format. Send raw audio bytes or a JSON body with base64 audio.")
80
+
81
+ def _decode_base64_audio(self, value: str) -> bytes:
82
+ if "," in value and value.startswith("data:"):
83
+ value = value.split(",", 1)[1]
84
+ return base64.b64decode(value)
85
+
86
+ def _load_audio(self, audio_bytes: bytes) -> tuple[torch.Tensor, int]:
87
+ waveform, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
88
+ waveform = torch.from_numpy(waveform.T)
89
+
90
+ if waveform.shape[0] > 1:
91
+ waveform = waveform.mean(dim=0, keepdim=True)
92
+
93
+ if sample_rate != TARGET_SAMPLE_RATE:
94
+ resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
95
+ waveform = resampler(waveform)
96
+
97
+ return waveform.squeeze(0), TARGET_SAMPLE_RATE
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ speechbrain==1.0.3
2
+ torch
3
+ torchaudio
4
+ soundfile
5
+ numpy