File size: 3,912 Bytes
71605c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
HF Inference Endpoint handler for HT-Demucs FT.

When deployed to an HF Inference Endpoint, HF instantiates EndpointHandler
once at container startup (downloading the demucs checkpoints into the
container cache), then calls __call__ on every HTTP request.

Request shape:
    POST /
    Content-Type: application/json
    {
      "inputs": "<base64-encoded audio bytes; any libsndfile-readable format>",
      "parameters": {
        "stems": ["vocals", "drums", "bass", "other"]  // optional; defaults to all 4. Unknown names are ignored (all 4 are returned if none match).
      }
    }

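Response shape (only the requested stems are present; each WAV is 16-bit PCM):
    {
      "vocals":  "<base64 WAV>",
      "drums":   "<base64 WAV>",
      "bass":    "<base64 WAV>",
      "other":   "<base64 WAV>",
      "sample_rate": 44100,
      "duration_s":  123.4
    }

If the request cannot be processed, the response is {"error": "<message>"} instead.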

To deploy:
    1) Create the endpoint in the HF UI (Deploy -> Inference Endpoints on the
       model card), choose a GPU instance (T4 small minimum; L4 recommended)
    2) Send requests as shown above.

Or skip self-hosting and use the StemSplit API:
    https://stemsplit.io/developers
"""
from __future__ import annotations

import base64
import io
from typing import Any

import numpy as np
import soundfile as sf
import torch
from demucs.apply import apply_model
from demucs.audio import convert_audio
from demucs.pretrained import get_model

# Stem names requested when the request body omits "parameters.stems".
DEFAULT_STEMS: tuple[str, ...] = (
    "vocals",
    "drums",
    "bass",
    "other",
)


def _audio_to_b64_wav(audio: torch.Tensor, sample_rate: int) -> str:
    """Encode a (channels, samples) FP32 tensor as base64-PCM16 WAV."""
    np_audio = audio.cpu().numpy().T  # -> (samples, channels)
    np_audio = np.clip(np_audio, -1.0, 1.0)
    buf = io.BytesIO()
    sf.write(buf, np_audio, sample_rate, subtype="PCM_16", format="WAV")
    return base64.b64encode(buf.getvalue()).decode("ascii")


class EndpointHandler:
    """HF Inference Endpoint entrypoint wrapping the ``htdemucs_ft`` model.

    The model is loaded once in ``__init__``. Each ``__call__`` decodes one
    request, separates it, and returns base64 WAV stems, or ``{"error": ...}``
    when the request cannot be processed.
    """

    def __init__(self, path: str = "") -> None:
        """Load ``htdemucs_ft`` and move it to the best available device.

        Args:
            path: Unused here; presumably the model repository path supplied
                by the endpoint runtime (the checkpoint comes from
                ``get_model`` instead).
        """
        self.model = get_model("htdemucs_ft")
        self.model.eval()
        # Prefer CUDA, then Apple MPS, then fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )
        self.model.to(self.device)
        self.sample_rate = int(self.model.samplerate)
        self.audio_channels = int(self.model.audio_channels)
        self.sources = list(self.model.sources)

    def _select_stems(self, stems: Any) -> list[str]:
        """Resolve a ``parameters.stems`` value into known source names.

        ``None`` means ``DEFAULT_STEMS``. A bare string is treated as a single
        stem name rather than being iterated character by character. Any other
        non-list/tuple value selects nothing. Unknown names, non-string entries
        and duplicates are dropped, and order is preserved. If nothing valid
        remains, every model source is returned.
        """
        if stems is None:
            stems = DEFAULT_STEMS
        elif isinstance(stems, str):
            stems = [stems]
        elif not isinstance(stems, (list, tuple)):
            stems = ()

        selected: list[str] = []
        for name in stems:
            if isinstance(name, str) and name in self.sources and name not in selected:
                selected.append(name)
        return selected or list(self.sources)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Separate the audio in one request into stems.

        Args:
            data: Request body. ``data["inputs"]`` holds base64-encoded audio.
                The optional ``data["parameters"]["stems"]`` selects which
                stems to return.

        Returns:
            ``{<stem>: <base64 PCM16 WAV>, ..., "sample_rate": int,
            "duration_s": float}`` on success, or ``{"error": <message>}``.
        """
        if "inputs" not in data:
            return {"error": "Request body must include base64 audio under 'inputs'."}

        # Validate cheap request fields before doing any decoding work.
        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return {"error": "'parameters' must be a JSON object."}

        # b64decode raises binascii.Error (a ValueError) on bad padding,
        # ValueError on non-ASCII str input, and TypeError on non-str/bytes.
        # Report these as request errors rather than crashing the handler.
        try:
            audio_bytes = base64.b64decode(data["inputs"])
        except (ValueError, TypeError) as e:
            return {"error": f"'inputs' is not valid base64: {type(e).__name__}: {e}"}

        try:
            wav_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        except Exception as e:  # noqa: BLE001
            return {"error": f"Could not decode audio: {type(e).__name__}: {e}"}
        if wav_np.shape[0] == 0:
            return {"error": "Decoded audio contains no samples."}

        # wav_np: (samples, channels) -> (channels, samples) FP32, then match
        # the model's sample rate and channel count.
        wav = torch.from_numpy(wav_np.T).contiguous()
        wav = convert_audio(wav, sr, self.sample_rate, self.audio_channels)
        # The std used for normalization below is NaN for a single sample.
        if wav.shape[-1] < 2:
            return {"error": "Decoded audio is too short to separate."}
        wav = wav.to(self.device)

        requested_stems = self._select_stems(params.get("stems"))

        # Mirror the input normalization of demucs' reference separation code:
        # zero-mean / unit-std on the mono mix, undone on the outputs. The
        # epsilon avoids dividing by zero on silent input.
        ref = wav.mean(0)
        offset = float(ref.mean())
        scale = float(ref.std()) + 1e-8

        with torch.no_grad():
            # apply_model handles overlap-add segmentation internally
            stems = apply_model(
                self.model,
                ((wav - offset) / scale).unsqueeze(0),  # (1, channels, samples)
                device=str(self.device),
                progress=False,
            )[0]
            # stems: (n_sources, channels, samples). De-normalize, then move to
            # CPU once instead of once per encoded stem.
            stems = (stems * scale + offset).cpu()

        out: dict[str, Any] = {
            "sample_rate": self.sample_rate,
            "duration_s": round(wav.shape[-1] / self.sample_rate, 3),
        }
        for stem in requested_stems:
            out[stem] = _audio_to_b64_wav(stems[self.sources.index(stem)], self.sample_rate)
        return out