|
|
|
|
|
import io, base64, wave |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoProcessor, CsmForConditionalGeneration |
|
|
|
|
|
# CSM emits 24 kHz audio; used both as the WAV frame rate and in the response payload.
SAMPLING_RATE = 24000
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoints handler for Sesame CSM text-to-speech.

    Loads the CSM model and processor from ``path`` and converts either a
    plain text prompt or a chat-style conversation into a base64-encoded,
    mono, 16-bit, 24 kHz WAV file.
    """

    def __init__(self, path=""):
        """Load the processor and model from ``path`` onto GPU if available.

        Args:
            path: Local directory or hub id of the pretrained CSM checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 halves GPU memory; CPU inference generally needs fp32 kernels.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        self.model.eval()

    def _wav_bytes(self, audio_f32, sr=SAMPLING_RATE):
        """Encode a float waveform (values in [-1, 1]) as mono 16-bit WAV bytes.

        Args:
            audio_f32: Float array-like of samples; any shape (flattened to
                one channel). Values outside [-1, 1] are clipped.
            sr: Sample rate written into the WAV header.

        Returns:
            The complete WAV file as ``bytes``.
        """
        # BUGFIX: generate() may return audio with a leading batch dimension;
        # writing a 2-D int16 array directly would interleave rows into noise.
        # Flatten to 1-D mono before quantization.
        samples = np.asarray(audio_f32, dtype=np.float32).ravel()
        clipped = np.clip(samples, -1.0, 1.0)
        pcm16 = (clipped * 32767.0).astype(np.int16)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)   # mono
            wf.setsampwidth(2)   # 16-bit samples
            wf.setframerate(sr)
            wf.writeframes(pcm16.tobytes())
        return buf.getvalue()

    @staticmethod
    def _to_numpy(out):
        """Coerce generate() output (ndarray, list, or tensor) to an ndarray.

        Handles the shapes CSM's ``generate(output_audio=True)`` has been
        observed to return: a raw ndarray, a list of per-sample tensors
        (the first entry is taken), or a single tensor.
        """
        if isinstance(out, np.ndarray):
            return out
        if isinstance(out, list):
            first = out[0] if len(out) > 0 else out
            if hasattr(first, "cpu"):
                # BUGFIX: detach before .numpy(), matching the tensor branch
                # below (fails on tensors that require grad otherwise).
                return first.detach().cpu().numpy()
            return np.asarray(first)
        if hasattr(out, "cpu"):
            return out.detach().cpu().numpy()
        return np.asarray(out)

    def __call__(self, data):
        """
        Accepts either:
        { "inputs": "Hello there", "parameters": {"speaker": 0, "max_length": 250} }
        or:
        {
        "conversation": [
        {"role":"0","content":[{"type":"text","text":"Say hi!"}]}
        ],
        "parameters": {"speaker": 0}
        }

        Returns a dict with ``audio_base64`` (WAV), ``sampling_rate`` and
        ``format``.
        """
        params = data.get("parameters") or {}
        speaker = int(params.get("speaker", 0))
        max_length = int(params.get("max_length", 250))

        if "conversation" in data:
            conversation = data["conversation"]
            # BUGFIX: the old loop only rewrote messages that ALREADY had a
            # "role" and skipped the ones missing it — backwards: role-less
            # messages break apply_chat_template. Always set the role so every
            # message is valid (keeps the original forced-speaker semantics).
            # NOTE(review): this discards caller-provided per-message speakers;
            # confirm whether multi-speaker conversations should be honored.
            for msg in conversation:
                msg["role"] = str(speaker)
        else:
            text = data.get("inputs") or ""
            conversation = [{"role": str(speaker),
                             "content": [{"type": "text", "text": text}]}]

        inputs = self.processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_length=max_length,
                output_audio=True,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
            )

        audio = self._to_numpy(out)
        wav_b = self._wav_bytes(audio, SAMPLING_RATE)
        return {
            "audio_base64": base64.b64encode(wav_b).decode("ascii"),
            "sampling_rate": SAMPLING_RATE,
            "format": "wav",
        }
|
|
|