File size: 3,442 Bytes
551f275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
HF Inference Endpoint handler for the HT-Demucs FT **drums** specialist.

This repo ships only sub-model 0 of the 4-model htdemucs_ft bag — the one
trained to extract `drums`. It is ~160 MB on disk and costs ~1/4 as much to
run as the full bag, with the same per-stem quality as our v1.1 benchmark
(median drums SDR = 10.11 dB).

If you need all 4 stems in one request, use the full ensemble:
    https://huggingface.co/StemSplitio/htdemucs-ft-pytorch

Request shape:
    POST /
    Content-Type: application/json
    { "inputs": "<base64-encoded audio bytes>" }

Response shape:
    { "drums": "<base64 WAV>", "sample_rate": 44100, "duration_s": 123.4 }
On invalid input the response is instead:
    { "error": "<human-readable message>" }
"""
from __future__ import annotations

import base64
import io
from typing import Any

import numpy as np
import soundfile as sf
import torch
from demucs.apply import apply_model
from demucs.audio import convert_audio
from demucs.pretrained import get_model

# Position of the kept sub-model inside the htdemucs_ft bag, and the name of
# the one stem this handler returns from that sub-model's outputs.
BAG_INDEX: int = 0
TARGET_STEM: str = "drums"


def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str:
    """Serialize one stem as a base64-encoded, 16-bit PCM WAV file.

    Args:
        audio: Stem tensor, expected as (channels, samples) like the rows the
            handler extracts from the separation output. It may live on any
            device; it is copied to the CPU first.
        sample_rate: Sample rate written into the WAV header.

    Returns:
        The WAV file bytes as an ASCII base64 string.
    """
    # soundfile wants (frames, channels), so flip the axes.
    frames = audio.cpu().numpy().T
    # Clamp to [-1, 1] so out-of-range samples saturate when quantized to
    # 16-bit PCM.
    frames = np.clip(frames, -1.0, 1.0)
    with io.BytesIO() as wav_file:
        sf.write(wav_file, frames, sample_rate, subtype="PCM_16", format="WAV")
        payload = wav_file.getvalue()
    return base64.b64encode(payload).decode("ascii")


class EndpointHandler:
    """HF Inference Endpoint handler running a single htdemucs_ft sub-model.

    Only the sub-model at ``BAG_INDEX`` is kept, and only its ``TARGET_STEM``
    output row is returned. See the module docstring for request/response
    formats.
    """

    def __init__(self, path: str = "") -> None:
        """Load the htdemucs_ft bag and keep only the target sub-model.

        Args:
            path: Repository path passed in by the endpoint runtime. It is not
                used here; weights come from ``demucs.pretrained.get_model``.
        """
        # Load the full bag, then keep a single sub-model. Nothing below holds
        # a reference to ``bag`` itself, so the other sub-models are not
        # retained by this handler.
        bag = get_model("htdemucs_ft")
        self.model = bag.models[BAG_INDEX]
        self.model.eval()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )
        self.model.to(self.device)
        # apply_model() is run on the sub-model, so its output rows follow the
        # sub-model's own ``sources``; read all metadata from the model that
        # actually runs rather than from the bag wrapper.
        self.sample_rate = int(self.model.samplerate)
        self.audio_channels = int(self.model.audio_channels)
        self.sources = list(self.model.sources)  # ["drums","bass","other","vocals"]
        self.target_index = self.sources.index(TARGET_STEM)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Separate the target stem from one base64-encoded audio file.

        Args:
            data: Request payload; the base64 audio is read from
                ``data["inputs"]``.

        Returns:
            On success: the target stem as a base64 16-bit PCM WAV (keyed by
            ``TARGET_STEM``), plus ``sample_rate`` and ``duration_s`` measured
            at the model's sample rate. On invalid input: ``{"error": msg}``.
        """
        if "inputs" not in data:
            return {"error": "Request body must include base64 audio under 'inputs'."}

        try:
            audio_bytes = base64.b64decode(data["inputs"])
            wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        except Exception as e:  # noqa: BLE001
            return {"error": f"Could not decode audio: {type(e).__name__}: {e}"}

        # A well-formed file with zero frames would otherwise fail later inside
        # resampling / separation with an unhandled exception.
        if wav_np.shape[0] == 0:
            return {"error": "Could not decode audio: the file contains no samples."}

        # sf.read(always_2d=True) returns (frames, channels); demucs works on
        # (channels, frames).
        wav = torch.from_numpy(wav_np.T).contiguous()
        wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels)
        # Leave the full track on the CPU and let apply_model ship one segment
        # at a time to self.device. The whole-track input and the
        # (n_sources, channels, samples) output buffer then stay in host
        # memory instead of occupying accelerator memory for long inputs.
        wav = wav.unsqueeze(0)

        with torch.no_grad():
            # apply_model on a single Model (not a BagOfModels) is supported
            # and runs only this specialist — 1/4 the cost of the full bag.
            # NOTE(review): apply_model's default shifts=1 applies a random
            # time shift, so outputs may not be bit-identical across calls.
            stems = apply_model(self.model, wav, device=str(self.device), progress=False)[0]
            # stems: (n_sources, channels, samples). Only stems[target_index]
            # is meaningful for this specialist — the other rows are weakly
            # predicted by-products and should not be used.

        return {
            TARGET_STEM: _audio_to_b64_wav(stems[self.target_index], self.sample_rate),
            "sample_rate": self.sample_rate,
            "duration_s": round(wav.shape[-1] / self.sample_rate, 3),
        }