File size: 3,659 Bytes
8e9a34d
 
4b34089
8e9a34d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e467ea5
8e9a34d
 
 
e467ea5
8e9a34d
e467ea5
 
 
 
8e9a34d
 
e467ea5
8e9a34d
4b34089
8e9a34d
 
 
 
 
 
4b34089
 
8e9a34d
4b34089
8e9a34d
 
f94872a
 
 
4b34089
 
8e9a34d
39064bf
 
 
 
a971d4b
 
 
 
 
 
39064bf
a971d4b
 
39064bf
 
 
 
8e9a34d
4b34089
8e9a34d
 
 
4b34089
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# handler.py
import io, base64, wave
import numpy as np
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration

SAMPLING_RATE = 24000  # CSM outputs 24 kHz mono

class EndpointHandler:
    """Inference-endpoint handler for the Sesame CSM text-to-speech model.

    Loads the CSM processor and model from ``path`` and, on call, turns a
    text prompt (or a chat-style conversation) into a base64-encoded
    16-bit mono PCM WAV blob at 24 kHz.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; fall back to CPU inference.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load processor + model specific to CSM.
        self.processor = AutoProcessor.from_pretrained(path)
        # fp16 halves GPU memory; fp32 is kept for CPU, where half precision
        # is slow and poorly supported.
        self.model = CsmForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)
        self.model.eval()

    def _wav_bytes(self, audio_f32, sr=None):
        """Encode a float waveform in [-1, 1] as 16-bit mono PCM WAV bytes.

        Parameters
        ----------
        audio_f32 : array-like of float
            Waveform samples. Flattened to 1-D (e.g. a ``(1, n)`` batch
            axis from ``generate``) and clipped to [-1, 1].
        sr : int, optional
            Sample rate; defaults to ``SAMPLING_RATE`` (24 kHz).

        Returns
        -------
        bytes
            A complete RIFF/WAVE container.
        """
        if sr is None:
            sr = SAMPLING_RATE
        # Work in float32 so fp16 model output is not scaled in reduced
        # precision, and ravel so multi-dim output can't corrupt framing.
        samples = np.asarray(audio_f32, dtype=np.float32).ravel()
        samples = np.clip(samples, -1.0, 1.0)
        pcm16 = (samples * 32767.0).astype(np.int16)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)   # mono
            wf.setsampwidth(2)   # 16-bit samples
            wf.setframerate(sr)
            wf.writeframes(pcm16.tobytes())
        return buf.getvalue()

    @staticmethod
    def _to_numpy(out):
        """Normalize CSM ``generate`` output (tensor, list, or ndarray) to numpy.

        ``generate(..., output_audio=True)`` has returned different container
        types across transformers versions, so handle each shape explicitly.
        """
        if isinstance(out, np.ndarray):
            return out
        if isinstance(out, list):
            # A (possibly empty) list of per-sample waveforms: take the first.
            first = out[0] if len(out) > 0 else out
            if hasattr(first, "cpu"):
                # detach() before cpu()/numpy() so grad-tracking tensors
                # convert cleanly, matching the plain-tensor branch below.
                return first.detach().cpu().numpy()
            return np.array(first)
        if hasattr(out, "cpu"):
            # A tensor (possibly on CUDA): move to host memory first.
            return out.detach().cpu().numpy()
        # Fallback: trust numpy to coerce whatever this is.
        return np.array(out)

    def __call__(self, data):
        """Run TTS inference for one request.

        Accepts either:
          { "inputs": "Hello there", "parameters": {"speaker": 0, "max_length": 250} }
        or:
          {
            "conversation": [
              {"role":"0","content":[{"type":"text","text":"Say hi!"}]}
            ],
            "parameters": {"speaker": 0}
          }

        Returns a dict with ``audio_base64`` (WAV bytes, base64),
        ``sampling_rate`` and ``format``.
        """
        params = data.get("parameters") or {}
        # CSM encodes the speaker identity as the chat "role" string.
        speaker = int(params.get("speaker", 0))
        max_length = int(params.get("max_length", 250))

        if "conversation" in data:
            conversation = data["conversation"]
            # Force every turn onto the requested speaker so the generated
            # voice stays consistent with the `speaker` parameter.
            for msg in conversation:
                if "role" in msg:
                    msg["role"] = str(speaker)
        else:
            text = data.get("inputs") or ""
            conversation = [{"role": str(speaker),
                             "content": [{"type": "text", "text": text}]}]

        inputs = self.processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_length=max_length,
                output_audio=True,   # ask CSM for a waveform, not tokens
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
            )

        audio = self._to_numpy(out)
        wav_b = self._wav_bytes(audio, SAMPLING_RATE)
        return {
            "audio_base64": base64.b64encode(wav_b).decode("ascii"),
            "sampling_rate": SAMPLING_RATE,
            "format": "wav",
        }