File size: 3,441 Bytes
"""
HF Inference Endpoint handler for the HT-Demucs FT **other** specialist.
This repo ships only sub-model 2 of the 4-bag htdemucs_ft ensemble
— the one trained to extract `other`. ~160 MB on disk and ~1/4 the inference
cost of the full bag, with the same per-stem quality as our v1.1 benchmark
(median other SDR = 6.34 dB).
If you need all 4 stems in one request, use the full ensemble:
https://huggingface.co/StemSplitio/htdemucs-ft-pytorch
Request shape:
POST /
Content-Type: application/json
{ "inputs": "<base64-encoded audio bytes>" }
Response shape:
{ "other": "<base64 WAV>", "sample_rate": 44100, "duration_s": 123.4 }
"""
from __future__ import annotations
import base64
import io
from typing import Any
import numpy as np
import soundfile as sf
import torch
from demucs.apply import apply_model
from demucs.audio import convert_audio
from demucs.pretrained import get_model
# Which sub-model of the htdemucs_ft bag to ship + which output index is ours.
# BAG_INDEX selects the entry of the loaded bag's `models` list that the
# handler keeps and runs.
BAG_INDEX = 2
# TARGET_STEM is looked up in the bag's `sources` list to find which output
# row is returned to the caller.
TARGET_STEM = "other"
def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str:
    """Serialize an audio tensor to a base64-encoded 16-bit PCM WAV string.

    The tensor is moved to the CPU, transposed, and clipped to [-1.0, 1.0]
    before being written as an in-memory WAV file. The caller passes one
    stem, presumably laid out as (channels, samples), so the transpose
    yields the (frames, channels) layout handed to ``sf.write``.

    Args:
        audio: Audio tensor for a single stem.
        sample_rate: Sample rate, in Hz, written into the WAV header.

    Returns:
        The WAV file contents, base64-encoded as an ASCII string.
    """
    frames = audio.cpu().numpy().T
    # Clip so out-of-range floats do not exceed full scale in the PCM output.
    frames = np.clip(frames, -1.0, 1.0)
    with io.BytesIO() as wav_file:
        sf.write(wav_file, frames, sample_rate, subtype="PCM_16", format="WAV")
        encoded = base64.b64encode(wav_file.getvalue())
    return encoded.decode("ascii")
class EndpointHandler:
    """Inference Endpoint handler that separates the ``other`` stem.

    On construction the ``htdemucs_ft`` bag is loaded and only the sub-model
    at ``BAG_INDEX`` is kept. Each call decodes base64 audio from the request,
    runs that sub-model, and returns the ``TARGET_STEM`` output as a
    base64-encoded WAV.
    """

    def __init__(self, path: str = "") -> None:
        """Load the specialist sub-model and move it to the selected device.

        Args:
            path: Accepted for handler-interface compatibility; it is not
                read by this implementation.
        """
        # Load the full bag, then keep a reference only to the specialist
        # sub-model so the other 3 can be released once `bag` goes out of
        # scope.
        bag = get_model("htdemucs_ft")
        self.model = bag.models[BAG_INDEX]
        self.model.eval()

        self.device = self._pick_device()
        self.model.to(self.device)

        self.sample_rate = int(bag.samplerate)
        self.audio_channels = int(bag.audio_channels)
        self.sources = list(bag.sources)  # ["drums","bass","other","vocals"]
        self.target_index = self.sources.index(TARGET_STEM)

    @staticmethod
    def _pick_device() -> torch.device:
        """Return the preferred device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # `torch.backends.mps` does not exist on older torch builds, so look
        # it up defensively instead of assuming the attribute is present.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Separate the target stem from the audio in a request payload.

        Args:
            data: Request body; must be a dict carrying base64-encoded audio
                bytes under the ``"inputs"`` key.

        Returns:
            On success, a dict with ``"other"`` (base64 WAV), ``"sample_rate"``
            and ``"duration_s"``. On invalid input, a dict with a single
            ``"error"`` message instead of raising.
        """
        # A non-dict body (e.g. None) is reported like a missing key rather
        # than raising TypeError from the membership test.
        if not isinstance(data, dict) or "inputs" not in data:
            return {"error": "Request body must include base64 audio under 'inputs'."}

        try:
            audio_bytes = base64.b64decode(data["inputs"])
            wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        except Exception as e:  # noqa: BLE001
            return {"error": f"Could not decode audio: {type(e).__name__}: {e}"}

        # A file that decodes to zero samples cannot be separated; report it
        # as a client error instead of letting the model call fail.
        if wav_np.size == 0:
            return {"error": "Audio contains no samples."}

        # sf.read yields (frames, channels); transpose to channels-first.
        wav = torch.from_numpy(wav_np.T).contiguous()
        wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels)
        wav = wav.unsqueeze(0).to(self.device)  # add the batch dimension

        with torch.no_grad():
            # apply_model on a single Model (not a BagOfModels) is supported
            # and runs only this specialist — 1/4 the cost of the full bag.
            stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0]

        # stems: (n_sources, channels, samples). Only stems[target_index]
        # is meaningful for this specialist — the other rows are weakly
        # predicted by-products and should not be used.
        return {
            "other": _audio_to_b64_wav(stems[self.target_index], self.sample_rate),
            "sample_rate": self.sample_rate,
            "duration_s": round(wav.shape[-1] / self.sample_rate, 3),
        }
|