File size: 3,379 Bytes
171bfcf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from typing import Dict, List, Any
from transformers import VitsModel, VitsTokenizer
import torch
import numpy as np
import base64
import soundfile as sf
import io
def normalize_waveform(waveform):
    """
    Normalize waveform values to the [-1, 1] range for audio playback.

    Args:
        waveform (np.ndarray): The waveform array to normalize.

    Returns:
        np.ndarray: The normalized waveform array. An all-zero (silent)
        waveform is returned unchanged instead of producing NaN/inf.
    """
    peak = np.max(np.abs(waveform))
    # Guard against division by zero for a silent waveform; the original
    # code divided unconditionally and produced NaNs for all-zero input.
    if peak == 0:
        return waveform
    return waveform / peak
def waveform_to_bytes(waveform):
    """
    Serialize a waveform into raw float32 sample bytes.

    The waveform is first normalized to the [-1, 1] range, then cast to
    32-bit floats and dumped as a contiguous byte string.

    Args:
        waveform (np.ndarray): The waveform samples to serialize.

    Returns:
        bytes: Raw float32 byte representation of the normalized samples.
    """
    scaled = normalize_waveform(waveform)  # Optional normalization
    return scaled.astype(np.float32).tobytes()
def waveform_to_base64(waveform):
    """
    Convert a waveform array to a base64-encoded string.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        str: Base64 encoding of the waveform's raw float32 bytes.
    """
    # Bug fix: the previous version called BytesIO() without importing it
    # (only `io` is imported at module level), raising NameError at runtime.
    # Encoding the byte string directly also removes a needless
    # write/seek/getvalue round-trip through an in-memory buffer.
    waveform_bytes = waveform_to_bytes(waveform)
    return base64.b64encode(waveform_bytes).decode('utf-8')
class EndpointHandler:
    """Inference endpoint wrapper around a VITS text-to-speech model."""

    def __init__(self, path: str):
        """
        Initialize the endpoint with the model path.

        Args:
            path (str): The file path or model ID for loading the model.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = VitsTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Process a prediction request using the loaded model.

        Args:
            data (Dict[str, Any]): Request body; must contain a non-empty
                'inputs' key holding a string or a list of strings.

        Returns:
            str: Base64-encoded WAV audio (see ``generate_predictions``).
                The original annotation claimed ``List[Dict[str, Any]]``,
                which did not match the actual return value.

        Raises:
            ValueError: If 'inputs' is missing or empty.
            TypeError: If any element of 'inputs' is not a string.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("The 'inputs' key is required in the data dictionary and cannot be empty.")
        if isinstance(inputs, str):
            inputs = [inputs]  # Convert to list to handle consistently as batch
        if not all(isinstance(i, str) for i in inputs):
            raise TypeError("All inputs must be strings.")
        return self.generate_predictions(inputs)

    def generate_predictions(self, texts: List[str]) -> str:
        """
        Synthesize speech for the given texts and return it as base64 WAV.

        Args:
            texts (List[str]): Texts for which to generate speech.

        Returns:
            str: Base64-encoded WAV audio for the FIRST text only.

        Note:
            Only ``output.numpy()[0]`` — the first waveform of the batch —
            is encoded; any additional texts are synthesized but discarded.
            Kept as-is to preserve the existing response contract.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        with torch.no_grad():
            output = self.model(**inputs).waveform
        buffer = io.BytesIO()
        sf.write(buffer, output.numpy()[0], self.model.config.sampling_rate, format='WAV')
        buffer.seek(0)  # Rewind the buffer to the beginning
        return base64.b64encode(buffer.read()).decode('utf-8')