| """ |
| HF Inference Endpoint handler for the HT-Demucs FT **other** specialist. |
| |
| This repo ships only sub-model 2 of the 4-bag htdemucs_ft ensemble |
| — the one trained to extract `other`. ~160 MB on disk and ~1/4 the inference |
| cost of the full bag, with the same per-stem quality as our v1.1 benchmark |
| (median other SDR = 6.34 dB). |
| |
| If you need all 4 stems in one request, use the full ensemble: |
| https://huggingface.co/StemSplitio/htdemucs-ft-pytorch |
| |
| Request shape: |
| POST / |
| Content-Type: application/json |
| { "inputs": "<base64-encoded audio bytes>" } |
| |
| Response shape: |
| { "other": "<base64 WAV>", "sample_rate": 44100, "duration_s": 123.4 } |
| """ |
| from __future__ import annotations |
|
|
| import base64 |
| import io |
| from typing import Any |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| from demucs.apply import apply_model |
| from demucs.audio import convert_audio |
| from demucs.pretrained import get_model |
|
|
| |
# Position of the sub-model to keep from the loaded ``htdemucs_ft`` bag; used
# to index ``bag.models`` in ``EndpointHandler.__init__``.
BAG_INDEX = 2
# Name of the stem this handler returns; looked up in ``bag.sources`` to find
# which entry of the model output to encode.
TARGET_STEM = "other"
|
|
|
|
def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str:
    """Encode an audio tensor as a base64 string holding a 16-bit PCM WAV file.

    Args:
        audio: Tensor of samples. It is moved to the CPU and transposed before
            writing — presumably laid out as (channels, samples); confirm
            against the caller.
        sample_rate: Sample rate written into the WAV header.

    Returns:
        The WAV file bytes, base64-encoded and decoded to an ASCII ``str``.
    """
    samples = audio.cpu().numpy().T
    # Clamp to [-1, 1] before the PCM_16 write.
    samples = np.clip(samples, -1.0, 1.0)

    with io.BytesIO() as wav_buffer:
        sf.write(wav_buffer, samples, sample_rate, subtype="PCM_16", format="WAV")
        wav_bytes = wav_buffer.getvalue()

    encoded = base64.b64encode(wav_bytes)
    return encoded.decode("ascii")
|
|
|
|
class EndpointHandler:
    """Serve a single stem from one sub-model of the ``htdemucs_ft`` bag.

    The constructor loads the bag once and keeps only the sub-model at
    ``BAG_INDEX``. Each call decodes a base64 audio payload, runs that one
    model, and returns the ``TARGET_STEM`` output as a base64-encoded WAV.
    """

    def __init__(self, path: str = "") -> None:
        """Load the specialist sub-model and move it to the selected device.

        Args:
            path: Accepted for interface compatibility but not read here; the
                weights are resolved by ``get_model("htdemucs_ft")``.
                NOTE(review): this loads the whole bag before discarding the
                other sub-models — confirm that is intended for this repo.
        """
        bag = get_model("htdemucs_ft")
        self.model = bag.models[BAG_INDEX]
        self.model.eval()
        self.device = self._select_device()
        self.model.to(self.device)
        # Audio format and stem order are taken from the bag itself.
        self.sample_rate = int(bag.samplerate)
        self.audio_channels = int(bag.audio_channels)
        self.sources = list(bag.sources)
        self.target_index = self.sources.index(TARGET_STEM)

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred torch device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # ``torch.backends.mps`` is absent on some torch builds; treat a
        # missing submodule as "MPS unavailable" instead of crashing at load.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Separate the target stem from a base64-encoded audio payload.

        Args:
            data: Request body; must contain base64 audio under ``"inputs"``.

        Returns:
            On success, a dict with ``"other"`` (base64 WAV), ``"sample_rate"``
            and ``"duration_s"``. On bad input, a dict with a single
            ``"error"`` message.
        """
        if "inputs" not in data:
            return {"error": "Request body must include base64 audio under 'inputs'."}

        # Request boundary: any decoding failure becomes a structured error
        # rather than an unhandled exception.
        try:
            audio_bytes = base64.b64decode(data["inputs"])
            wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        except Exception as e:
            return {"error": f"Could not decode audio: {type(e).__name__}: {e}"}

        # A valid container with zero frames would otherwise fail deep inside
        # resampling or inference with an opaque error.
        if wav_np.shape[0] == 0:
            return {"error": "Decoded audio contains no samples."}

        # The decoded array is transposed so the sample axis comes last, then
        # converted to the model's sample rate and channel count.
        wav = torch.from_numpy(wav_np.T).contiguous()
        wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels)
        wav = wav.unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Index 0 drops the batch axis added by ``unsqueeze(0)`` above.
            stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0]

        return {
            "other": _audio_to_b64_wav(stems[self.target_index], self.sample_rate),
            "sample_rate": self.sample_rate,
            # Duration is measured on the converted waveform.
            "duration_s": round(wav.shape[-1] / self.sample_rate, 3),
        }
|
|