| """ |
| HF Inference Endpoint handler for HT-Demucs FT. |
| |
| When deployed to an HF Inference Endpoint, HF instantiates EndpointHandler |
| once at container startup (downloading the demucs checkpoints into the |
| container cache), then calls __call__ on every HTTP request. |
| |
| Request shape: |
| POST / |
| Content-Type: application/json |
| { |
| "inputs": "<base64-encoded audio bytes; any libsndfile-readable format>", |
| "parameters": { |
| "stems": ["vocals", "drums", "bass", "other"] // optional, defaults to all 4; unknown names are ignored (all 4 are returned if none remain) |
| } |
| } |
| |
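| Response shape (only the requested stems are present; a request with missing 'inputs' or undecodable audio instead gets {"error": "<message>"}): |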
| { |
| "vocals": "<base64 WAV>", |
| "drums": "<base64 WAV>", |
| "bass": "<base64 WAV>", |
| "other": "<base64 WAV>", |
| "sample_rate": 44100, |
| "duration_s": 123.4 |
| } |
| |
| To deploy: |
| 1) Create the endpoint in the HF UI (Deploy -> Inference Endpoints on the |
| model card), choose a GPU instance (T4 small minimum; L4 recommended) |
| 2) Send requests as shown above. |
| |
| Or skip self-hosting and use the StemSplit API: |
| https://stemsplit.io/developers |
| """ |
| from __future__ import annotations |
|
|
| import base64 |
| import io |
| from typing import Any |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| from demucs.apply import apply_model |
| from demucs.audio import convert_audio |
| from demucs.pretrained import get_model |
|
|
# Stem names used when a request does not supply `parameters.stems`.
DEFAULT_STEMS: tuple[str, ...] = (
    "vocals",
    "drums",
    "bass",
    "other",
)
|
|
|
|
def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str:
    """Serialize a ``(channels, samples)`` float tensor as base64 WAV text.

    Samples are hard-clipped to ``[-1.0, 1.0]``, written in memory as a
    16-bit PCM WAV file, and the resulting bytes are base64-encoded.

    Args:
        audio: Audio laid out as ``(channels, samples)``. It is copied to
            the CPU first, so it may live on any device.
        sample_rate: Sample rate written into the WAV header.

    Returns:
        The WAV file encoded as an ASCII base64 string.
    """
    # Transpose so frames run along the first axis, as soundfile expects.
    frames = np.clip(audio.cpu().numpy().T, -1.0, 1.0)
    with io.BytesIO() as wav_buffer:
        sf.write(wav_buffer, frames, sample_rate, subtype="PCM_16", format="WAV")
        wav_bytes = wav_buffer.getvalue()
    encoded = base64.b64encode(wav_bytes)
    return encoded.decode("ascii")
|
|
|
|
class EndpointHandler:
    """HF Inference Endpoint entrypoint for HT-Demucs FT source separation.

    The model is loaded once in ``__init__``. Each HTTP request is handled by
    ``__call__``, which returns either the separated stems or an
    ``{"error": "<message>"}`` dict for invalid input.
    """

    def __init__(self, path: str = "") -> None:
        """Load ``htdemucs_ft`` and move it to the best available device.

        Args:
            path: Accepted to match the endpoint handler signature; unused
                here, because the checkpoint comes from
                ``demucs.pretrained.get_model``.
        """
        self.model = get_model("htdemucs_ft")
        self.model.eval()
        # Prefer CUDA, then Apple MPS, then fall back to the CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )
        self.model.to(self.device)
        self.sample_rate = int(self.model.samplerate)
        self.audio_channels = int(self.model.audio_channels)
        self.sources = list(self.model.sources)

    def _select_stems(self, params: Any) -> list[str]:
        """Resolve which stems to return from the request's ``parameters``.

        ``parameters.stems`` may be a single stem name or a list of names.
        Unknown names are dropped and duplicates collapse to their first
        occurrence. If ``stems`` is absent or null, ``DEFAULT_STEMS`` is
        used. If nothing valid remains, every model source is returned.

        Raises:
            ValueError: ``parameters`` is truthy but not a dict, or ``stems``
                is neither a string nor a list.
        """
        params = params or {}
        if not isinstance(params, dict):
            raise ValueError("'parameters' must be a JSON object.")
        requested = params.get("stems")
        if requested is None:
            requested = DEFAULT_STEMS
        elif isinstance(requested, str):
            # A bare string would otherwise be iterated character by character.
            requested = [requested]
        elif not isinstance(requested, (list, tuple)):
            raise ValueError("'parameters.stems' must be a stem name or a list of stem names.")
        # Filter before hashing (unhashable junk never reaches dict.fromkeys);
        # dict.fromkeys removes duplicates while keeping request order.
        selected = list(dict.fromkeys(s for s in requested if s in self.sources))
        return selected or list(self.sources)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Separate the audio in one request into stems.

        Args:
            data: Parsed request body. ``data["inputs"]`` holds base64-encoded
                audio. The optional ``data["parameters"]["stems"]`` selects
                the stems to return.

        Returns:
            A dict mapping each selected stem name to a base64 16-bit PCM WAV,
            plus ``sample_rate`` and ``duration_s``. Invalid input instead
            yields ``{"error": "<message>"}``.
        """
        if "inputs" not in data:
            return {"error": "Request body must include base64 audio under 'inputs'."}

        # binascii.Error (a ValueError) for malformed base64; TypeError when
        # 'inputs' is not a str/bytes value (e.g. a JSON number or list).
        try:
            audio_bytes = base64.b64decode(data["inputs"])
        except (ValueError, TypeError) as e:
            return {"error": f"Could not decode base64 'inputs': {type(e).__name__}: {e}"}

        try:
            wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        except Exception as e:  # request boundary: report any decoder failure
            return {"error": f"Could not decode audio: {type(e).__name__}: {e}"}
        if wav_np.shape[0] == 0:
            return {"error": "Decoded audio contains no samples."}

        try:
            requested_stems = self._select_stems(data.get("parameters"))
        except ValueError as e:
            return {"error": str(e)}

        # soundfile returns (frames, channels); transpose to (channels, frames)
        # before resampling/remixing to the model's rate and channel count.
        wav = torch.from_numpy(wav_np.T).contiguous()
        wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels)
        wav = wav.unsqueeze(0).to(self.device)  # batch of one; [0] below undoes it

        with torch.no_grad():
            stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0]

        out: dict[str, Any] = {
            "sample_rate": self.sample_rate,
            "duration_s": round(wav.shape[-1] / self.sample_rate, 3),
        }
        for stem in requested_stems:
            out[stem] = _audio_to_b64_wav(stems[self.sources.index(stem)], self.sample_rate)
        return out
|
|