# handler.py
import base64
import io
import wave

import numpy as np
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration

SAMPLING_RATE = 24000  # CSM outputs 24 kHz mono


class EndpointHandler:
    """Inference-endpoint wrapper around a CSM text-to-speech model.

    Loads the processor and model once at construction time, then serves
    ``__call__`` requests that return base64-encoded 16-bit mono WAV audio.
    """

    def __init__(self, path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load processor + model specific to CSM; fp16 only when a GPU is present.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)
        self.model.eval()

    def _wav_bytes(self, audio_f32, sr=SAMPLING_RATE):
        """Encode a float waveform in [-1, 1] as 16-bit mono PCM WAV bytes.

        Accepts any array-like; it is flattened to 1-D and clipped before
        quantization, so multi-dimensional model output (e.g. shape (1, N))
        is handled safely.
        """
        audio = np.asarray(audio_f32, dtype=np.float32).ravel()
        audio_i16 = (np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)  # mono
            wf.setsampwidth(2)  # 16-bit
            wf.setframerate(sr)
            wf.writeframes(audio_i16.tobytes())
        return buf.getvalue()

    @staticmethod
    def _to_numpy(out):
        """Normalize the possible generate() output shapes to a numpy array.

        CSM may return a numpy array, a (possibly CUDA) tensor, or a list of
        tensors/arrays (one per batch item); take the first batch element.
        """
        if isinstance(out, np.ndarray):
            return out
        if isinstance(out, list):
            first_item = out[0] if len(out) > 0 else out
            if hasattr(first_item, "cpu"):
                # detach() before cpu(): consistent with the tensor branch and
                # safe if autograd state is ever attached.
                return first_item.detach().cpu().numpy()
            return np.array(first_item)
        if hasattr(out, "cpu"):
            # Tensor (including CUDA tensors): move to CPU, convert to numpy.
            return out.detach().cpu().numpy()
        # Fallback: try to convert to numpy array.
        return np.array(out)

    def __call__(self, data):
        """
        Accepts either:
          { "inputs": "Hello there", "parameters": {"speaker": 0, "max_length": 250} }
        or:
          { "conversation": [ {"role":"0","content":[{"type":"text","text":"Say hi!"}]} ],
            "parameters": {"speaker": 0} }

        Returns a dict with base64-encoded WAV audio, the sampling rate,
        and the container format.
        """
        params = data.get("parameters") or {}
        speaker = int(params.get("speaker", 0))  # Keep as integer for consistency
        max_length = int(params.get("max_length", 250))

        if "conversation" in data:
            # Copy the messages so the caller's payload is never mutated.
            conversation = [dict(msg) for msg in data["conversation"]]
            # Override the speaker role only when it was explicitly provided;
            # previously the default (0) clobbered every role unconditionally.
            if "speaker" in params:
                for msg in conversation:
                    if "role" in msg:
                        msg["role"] = str(speaker)
        else:
            text = data.get("inputs") or ""
            conversation = [
                {"role": str(speaker), "content": [{"type": "text", "text": text}]}
            ]

        inputs = self.processor.apply_chat_template(
            conversation, tokenize=True, return_tensors="pt", return_dict=True
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_length=max_length,
                output_audio=True,  # CSM returns audio
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
            )

        # Convert to base64 WAV for the endpoint response.
        audio = self._to_numpy(out)
        wav_b = self._wav_bytes(audio, SAMPLING_RATE)
        return {
            "audio_base64": base64.b64encode(wav_b).decode("ascii"),
            "sampling_rate": SAMPLING_RATE,
            "format": "wav",
        }