# csm-1b / handler.py
# Author: farazmoradi98 — "Add custom handler for TTS inference"
# Commit: f94872a (verified)
# handler.py
import io, base64, wave
import numpy as np
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration
SAMPLING_RATE = 24000  # CSM outputs 24 kHz mono


class EndpointHandler:
    """Inference-endpoint handler wrapping Sesame CSM text-to-speech.

    Accepts either a plain text prompt or a CSM-style conversation and
    returns the generated speech as a base64-encoded 16-bit mono WAV.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; fp16 on CUDA halves memory, fp32 on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load processor + model specific to CSM.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)
        self.model.eval()

    def _wav_bytes(self, audio_f32, sr=SAMPLING_RATE):
        """Encode a float waveform in [-1, 1] as 16-bit mono WAV bytes.

        Any array-like is accepted (e.g. fp16 data or a (1, T) batch from
        the model); it is converted to float32 and flattened to 1-D before
        clipping and quantizing to PCM16.
        """
        samples = np.asarray(audio_f32, dtype=np.float32).ravel()
        pcm16 = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 16-bit
            wf.setframerate(sr)
            wf.writeframes(pcm16.tobytes())
        return buf.getvalue()

    @staticmethod
    def _to_numpy(out):
        """Normalize ``model.generate`` output to a numpy array.

        Depending on version/config CSM may return a numpy array, a list of
        per-sample tensors, or a single tensor (possibly on CUDA).
        """
        if isinstance(out, np.ndarray):
            return out
        if isinstance(out, (list, tuple)):
            if not out:
                # Nothing generated: yield an empty waveform rather than fail.
                return np.array([], dtype=np.float32)
            first = out[0]
            if hasattr(first, "cpu"):
                return first.cpu().numpy()
            return np.array(first)
        if hasattr(out, "cpu"):
            # Tensor (including CUDA): detach, move to CPU, convert.
            return out.detach().cpu().numpy()
        # Fallback: best-effort conversion.
        return np.array(out)

    def __call__(self, data):
        """
        Accepts either:
          { "inputs": "Hello there", "parameters": {"speaker": 0, "max_length": 250} }
        or:
          {
            "conversation": [
              {"role":"0","content":[{"type":"text","text":"Say hi!"}]}
            ],
            "parameters": {"speaker": 0}
          }
        Returns {"audio_base64": ..., "sampling_rate": 24000, "format": "wav"}.
        """
        params = data.get("parameters") or {}
        speaker = int(params.get("speaker", 0))  # Keep as integer for consistency
        max_length = int(params.get("max_length", 250))

        if "conversation" in data:
            conversation = data["conversation"]
            # Override roles only when the caller explicitly passed a speaker;
            # otherwise a multi-speaker conversation would silently collapse
            # to speaker 0. Build new message dicts so the caller's payload is
            # never mutated.
            if "speaker" in params:
                conversation = [
                    {**msg, "role": str(speaker)} if "role" in msg else msg
                    for msg in conversation
                ]
        else:
            text = data.get("inputs") or ""
            conversation = [
                {"role": str(speaker), "content": [{"type": "text", "text": text}]}
            ]

        inputs = self.processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_length=max_length,
                output_audio=True,  # CSM returns audio
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
            )

        # Convert to base64 WAV for the endpoint response.
        audio = self._to_numpy(out)
        wav_b = self._wav_bytes(audio, SAMPLING_RATE)
        return {
            "audio_base64": base64.b64encode(wav_b).decode("ascii"),
            "sampling_rate": SAMPLING_RATE,
            "format": "wav",
        }