import torch
from transformers import AutoProcessor, VibeVoiceAsrForConditionalGeneration
class EndpointHandler:
    """Callable request handler that transcribes audio with a VibeVoice ASR model.

    The processor and the model are loaded once in ``__init__`` and reused by
    every ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load the processor and the model from ``path``.

        Args:
            path: Location handed to both ``from_pretrained`` calls.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,  # weights are requested in half precision
            device_map="auto",
        )

    def __call__(self, data: dict) -> dict:
        """Transcribe one request and return the decoded text.

        Args:
            data: Request payload. ``data["inputs"]`` is forwarded as the
                ``audio`` argument of the processor; when the key is absent
                the remaining payload itself is forwarded. The optional
                ``data["prompt"]`` is forwarded as ``prompt`` (``None`` when
                missing). The payload is not modified.

        Returns:
            ``{"text": <first decoded transcription>}``.
        """
        # Pop from a shallow copy so the caller's dict keeps its keys; the
        # values handed to the processor are the same as popping in place.
        payload = dict(data)
        audio = payload.pop("inputs", payload)
        prompt = payload.pop("prompt", None)

        # Build the model inputs, then move them to the model's device/dtype.
        inputs = self.processor.apply_transcription_request(
            audio=audio,
            prompt=prompt,
            return_tensors="pt",
        ).to(self.model.device, self.model.dtype)

        with torch.no_grad():
            output_ids = self.model.generate(**inputs)

        # Skip the leading positions that correspond to the request itself.
        # NOTE(review): assumes generate() returns the input tokens followed
        # by the newly generated ones -- confirm against the model docs.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[:, prompt_length:]

        decoded = self.processor.decode(
            generated_ids,
            return_format="transcription_only",
        )
        # Only the first decoded entry is returned to the client.
        return {"text": decoded[0]}
|