# handler.py
import io, base64, wave
import numpy as np
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration
SAMPLING_RATE = 24000  # CSM outputs 24 kHz mono audio


class EndpointHandler:
    """Inference-endpoint handler for Sesame CSM text-to-speech.

    Loads the CSM processor and model once at startup and turns request
    payloads (plain text or a chat-style conversation) into base64-encoded
    16-bit mono WAV audio at 24 kHz.
    """

    def __init__(self, path=""):
        """Load the processor and model from *path* onto GPU if available.

        Half precision (float16) is used on CUDA to halve memory; full
        precision on CPU, where float16 inference is poorly supported.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        self.model.eval()

    def _wav_bytes(self, audio_f32, sr=SAMPLING_RATE):
        """Encode a float waveform in [-1, 1] as 16-bit mono PCM WAV bytes.

        Accepts any array-like shape (e.g. a ``(1, T)`` batch from the
        model) and flattens it, so callers need not squeeze first.
        """
        samples = np.asarray(audio_f32, dtype=np.float32).ravel()
        # Clip BEFORE scaling so out-of-range samples saturate at the
        # int16 limits instead of wrapping around in the cast.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 16-bit samples
            wf.setframerate(sr)
            wf.writeframes(pcm.tobytes())
        return buf.getvalue()

    @staticmethod
    def _to_numpy(out):
        """Normalize `generate(..., output_audio=True)` output to ndarray.

        CSM may return an ndarray, a list of tensors/arrays, or a tensor
        (possibly on CUDA); every branch ends in a CPU numpy array.
        """
        if isinstance(out, np.ndarray):
            return out
        if isinstance(out, list):
            first = out[0] if out else out
            if hasattr(first, "cpu"):
                # detach() for consistency with the tensor branch below
                # (a no-op under torch.no_grad()).
                return first.detach().cpu().numpy()
            return np.array(first)
        if hasattr(out, "cpu"):
            return out.detach().cpu().numpy()
        # Fallback: best-effort conversion.
        return np.array(out)

    def __call__(self, data):
        """Generate speech for a request payload.

        Accepts either:
            { "inputs": "Hello there", "parameters": {"speaker": 0, "max_length": 250} }
        or:
            {
              "conversation": [
                  {"role":"0","content":[{"type":"text","text":"Say hi!"}]}
              ],
              "parameters": {"speaker": 0}
            }

        Returns a dict with base64 WAV audio, its sampling rate and format.
        """
        params = data.get("parameters") or {}
        speaker = int(params.get("speaker", 0))  # keep as int for consistency
        max_length = int(params.get("max_length", 250))

        if "conversation" in data:
            conversation = data["conversation"]
            # Force every turn to the requested speaker id so the voice
            # stays consistent even if the payload mixes roles.
            for msg in conversation:
                if "role" in msg:
                    msg["role"] = str(speaker)
        else:
            text = data.get("inputs") or ""
            conversation = [{
                "role": str(speaker),
                "content": [{"type": "text", "text": text}],
            }]

        inputs = self.processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_length=max_length,
                output_audio=True,  # CSM returns audio waveforms
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
            )

        # Convert to base64 WAV for the endpoint response.
        audio = self._to_numpy(out)
        wav_b = self._wav_bytes(audio, SAMPLING_RATE)
        return {
            "audio_base64": base64.b64encode(wav_b).decode("ascii"),
            "sampling_rate": SAMPLING_RATE,
            "format": "wav",
        }