""" HF Inference Endpoint handler for the HT-Demucs FT **drums** specialist. This repo ships only sub-model 0 of the 4-bag htdemucs_ft ensemble — the one trained to extract `drums`. ~160 MB on disk and ~1/4 the inference cost of the full bag, with the same per-stem quality as our v1.1 benchmark (median drums SDR = 10.11 dB). If you need all 4 stems in one request, use the full ensemble: https://huggingface.co/StemSplitio/htdemucs-ft-pytorch Request shape: POST / Content-Type: application/json { "inputs": "" } Response shape: { "drums": "", "sample_rate": 44100, "duration_s": 123.4 } """ from __future__ import annotations import base64 import io from typing import Any import numpy as np import soundfile as sf import torch from demucs.apply import apply_model from demucs.audio import convert_audio from demucs.pretrained import get_model # Which sub-model of the htdemucs_ft bag to ship + which output index is ours. BAG_INDEX = 0 TARGET_STEM = "drums" def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str: np_audio = np.clip(audio.cpu().numpy().T, -1.0, 1.0) buf = io.BytesIO() sf.write(buf, np_audio, sample_rate, subtype="PCM_16", format="WAV") return base64.b64encode(buf.getvalue()).decode("ascii") class EndpointHandler: def __init__(self, path: str = "") -> None: # Load the full bag, then drop the other 3 sub-models so only the # drums specialist stays in memory. bag = get_model("htdemucs_ft") self.model = bag.models[BAG_INDEX] self.model.eval() self.device = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) self.model.to(self.device) self.sample_rate = int(bag.samplerate) self.audio_channels = int(bag.audio_channels) self.sources = list(bag.sources) # ["drums","bass","other","vocals"] self.target_index = self.sources.index(TARGET_STEM) def __call__(self, data: dict[str, Any]) -> dict[str, Any]: if "inputs" not in data: return {"error": "Request body must include base64 audio under 'inputs'."} try: audio_bytes = base64.b64decode(data["inputs"]) wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True) except Exception as e: # noqa: BLE001 return {"error": f"Could not decode audio: {type(e).__name__}: {e}"} wav = torch.from_numpy(wav_np.T).contiguous() wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels) wav = wav.unsqueeze(0).to(self.device) with torch.no_grad(): # apply_model on a single Model (not a BagOfModels) is supported # and runs only this specialist — 1/4 the cost of the full bag. stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0] # stems: (n_sources, channels, samples). Only stems[target_index] # is meaningful for this specialist — the other rows are weakly # predicted by-products and should not be used. return { "drums": _audio_to_b64_wav(stems[self.target_index], self.sample_rate), "sample_rate": self.sample_rate, "duration_s": round(wav.shape[-1] / self.sample_rate, 3), }