""" HF Inference Endpoint handler for HT-Demucs FT. When deployed to an HF Inference Endpoint, HF instantiates EndpointHandler once at container startup (downloading the demucs checkpoints into the container cache), then calls __call__ on every HTTP request. Request shape: POST / Content-Type: application/json { "inputs": "", "parameters": { "stems": ["vocals", "drums", "bass", "other"] // optional, defaults to all 4 } } Response shape: { "vocals": "", "drums": "", "bass": "", "other": "", "sample_rate": 44100, "duration_s": 123.4 } To deploy: 1) Create the endpoint in the HF UI (Deploy -> Inference Endpoints on the model card), choose a GPU instance (T4 small minimum; L4 recommended) 2) Send requests as shown above. Or skip self-hosting and use the StemSplit API: https://stemsplit.io/developers """ from __future__ import annotations import base64 import io from typing import Any import numpy as np import soundfile as sf import torch from demucs.apply import apply_model from demucs.audio import convert_audio from demucs.pretrained import get_model DEFAULT_STEMS = ("vocals", "drums", "bass", "other") def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str: """Encode a (channels, samples) FP32 tensor as base64-PCM16 WAV.""" np_audio = audio.cpu().numpy().T # -> (samples, channels) np_audio = np.clip(np_audio, -1.0, 1.0) buf = io.BytesIO() sf.write(buf, np_audio, sample_rate, subtype="PCM_16", format="WAV") return base64.b64encode(buf.getvalue()).decode("ascii") class EndpointHandler: """HF Inference Endpoint entrypoint.""" def __init__(self, path: str = "") -> None: self.model = get_model("htdemucs_ft") self.model.eval() self.device = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) self.model.to(self.device) self.sample_rate = int(self.model.samplerate) self.audio_channels = int(self.model.audio_channels) self.sources = list(self.model.sources) def __call__(self, data: dict[str, Any]) -> dict[str, Any]: if "inputs" not in data: return {"error": "Request body must include base64 audio under 'inputs'."} audio_bytes = base64.b64decode(data["inputs"]) try: wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True) except Exception as e: # noqa: BLE001 return {"error": f"Could not decode audio: {type(e).__name__}: {e}"} # wav_np: (samples, channels) -> (channels, samples) FP32 wav = torch.from_numpy(wav_np.T).contiguous() wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels) wav = wav.unsqueeze(0).to(self.device) # (1, channels, samples) # Optional stem filter params = data.get("parameters", {}) or {} requested_stems = [s for s in params.get("stems", DEFAULT_STEMS) if s in self.sources] if not requested_stems: requested_stems = list(self.sources) with torch.no_grad(): # apply_model handles overlap-add segmentation internally stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0] # stems: (n_sources, channels, samples) on `self.device` out: dict[str, Any] = { "sample_rate": self.sample_rate, "duration_s": round(wav.shape[-1] / self.sample_rate, 3), } for stem in requested_stems: idx = self.sources.index(stem) out[stem] = _audio_to_b64_wav(stems[idx], self.sample_rate) return out