File size: 3,912 Bytes
"""
HF Inference Endpoint handler for HT-Demucs FT.
When deployed to an HF Inference Endpoint, HF instantiates EndpointHandler
once at container startup (downloading the demucs checkpoints into the
container cache), then calls __call__ on every HTTP request.
Request shape:
POST /
Content-Type: application/json
{
"inputs": "<base64-encoded audio bytes; any libsndfile-readable format>",
"parameters": {
"stems": ["vocals", "drums", "bass", "other"] // optional, defaults to all 4
}
}
Response shape:
{
"vocals": "<base64 WAV>",
"drums": "<base64 WAV>",
"bass": "<base64 WAV>",
"other": "<base64 WAV>",
"sample_rate": 44100,
"duration_s": 123.4
}
To deploy:
1) Create the endpoint in the HF UI (Deploy -> Inference Endpoints on the
model card), choose a GPU instance (T4 small minimum; L4 recommended)
2) Send requests as shown above.
Or skip self-hosting and use the StemSplit API:
https://stemsplit.io/developers
"""
from __future__ import annotations
import base64
import io
from typing import Any
import numpy as np
import soundfile as sf
import torch
from demucs.apply import apply_model
from demucs.audio import convert_audio
from demucs.pretrained import get_model
# Stem names used when a request does not supply ``parameters.stems``.
# The handler still filters these against the loaded model's own source list.
DEFAULT_STEMS = ("vocals", "drums", "bass", "other")
def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str:
    """Serialize one stem as a base64 string wrapping a 16-bit PCM WAV file.

    Args:
        audio: FP32 tensor laid out as (channels, samples).
        sample_rate: Sample rate written into the WAV header.

    Returns:
        ASCII base64 text of the complete WAV file.
    """
    # soundfile expects frames first, so flip to (samples, channels), and
    # limit values to [-1, 1] before they are written as 16-bit PCM.
    frames = np.clip(audio.cpu().numpy().T, -1.0, 1.0)
    with io.BytesIO() as wav_file:
        sf.write(wav_file, frames, sample_rate, subtype="PCM_16", format="WAV")
        wav_bytes = wav_file.getvalue()
    return str(base64.b64encode(wav_bytes), "ascii")
class EndpointHandler:
    """HF Inference Endpoint entrypoint for HT-Demucs FT stem separation.

    The model is loaded once in ``__init__``. Each request is then served by
    ``__call__``, which decodes the uploaded audio, separates it into stems,
    and returns the requested stems as base64-encoded WAV strings.
    """

    def __init__(self, path: str = "") -> None:
        """Load ``htdemucs_ft`` and move it to the best available device.

        Args:
            path: Not used by this handler (presumably the model directory
                supplied by the hosting runtime).
        """
        self.model = get_model("htdemucs_ft")
        self.model.eval()

        # ``torch.backends.mps`` does not exist on every torch build, so probe
        # for it instead of assuming the attribute is present.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device_name = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)
        self.model.to(self.device)

        # Cache the model properties that every request needs.
        self.sample_rate = int(self.model.samplerate)
        self.audio_channels = int(self.model.audio_channels)
        self.sources = list(self.model.sources)

    def _select_stems(self, data: dict[str, Any]) -> list[str]:
        """Return the requested stem names that the loaded model can produce.

        Accepts a list of names or a single name given as a bare string.
        Unknown names are dropped and duplicates are collapsed (first
        occurrence wins). When nothing valid remains, every model source is
        returned.
        """
        params = data.get("parameters") or {}
        # ``or`` also covers an explicit ``"stems": null`` / empty list.
        wanted = params.get("stems") or DEFAULT_STEMS
        if isinstance(wanted, str):
            # Without this a bare "vocals" would be iterated per character.
            wanted = [wanted]
        selected = list(dict.fromkeys(s for s in wanted if s in self.sources))
        return selected or list(self.sources)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Separate the audio in ``data["inputs"]`` into stems.

        Args:
            data: Request payload. ``"inputs"`` holds base64-encoded audio
                bytes; the optional ``"parameters"`` mapping may carry
                ``"stems"``, the stem names to return.

        Returns:
            On success, a dict with ``"sample_rate"``, ``"duration_s"`` and
            one base64 WAV string per selected stem. When the payload is
            missing ``"inputs"`` or cannot be decoded, a dict with a single
            ``"error"`` message instead.
        """
        if "inputs" not in data:
            return {"error": "Request body must include base64 audio under 'inputs'."}

        # Malformed base64 raises binascii.Error (a ValueError subclass) and a
        # non-string payload raises TypeError; report both as a request error
        # rather than letting them escape the handler.
        try:
            audio_bytes = base64.b64decode(data["inputs"])
        except (ValueError, TypeError) as e:
            return {"error": f"Could not decode base64 input: {type(e).__name__}: {e}"}

        try:
            wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        except Exception as e:  # noqa: BLE001
            return {"error": f"Could not decode audio: {type(e).__name__}: {e}"}

        # Resolve the stem filter before any heavy work is started.
        requested_stems = self._select_stems(data)

        # wav_np: (samples, channels) -> (channels, samples) FP32
        wav = torch.from_numpy(wav_np.T).contiguous()
        wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels)
        wav = wav.unsqueeze(0).to(self.device)  # (1, channels, samples)

        with torch.no_grad():
            # apply_model handles overlap-add segmentation internally
            stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0]
        # stems: (n_sources, channels, samples) on `self.device`

        out: dict[str, Any] = {
            "sample_rate": self.sample_rate,
            "duration_s": round(wav.shape[-1] / self.sample_rate, 3),
        }
        for stem in requested_stems:
            out[stem] = _audio_to_b64_wav(stems[self.sources.index(stem)], self.sample_rate)
        return out
|