|
|
from typing import Dict, List, Any |
|
|
from transformers import VitsModel, VitsTokenizer |
|
|
import torch |
|
|
import numpy as np |
|
|
import base64 |
|
|
import soundfile as sf |
|
|
import io |
|
|
|
|
|
|
|
|
def normalize_waveform(waveform):
    """
    Normalize a waveform to the [-1, 1] range suitable for audio playback.

    Divides by the peak absolute amplitude. A silent (all-zero) input is
    returned as an array of zeros instead of producing NaNs from a
    division by zero.

    Args:
        waveform (np.ndarray): The waveform array to normalize.

    Returns:
        np.ndarray: The normalized waveform array (float dtype).
    """
    peak = np.max(np.abs(waveform))
    # Guard against division by zero on silent input, which would
    # otherwise yield an all-NaN array and a RuntimeWarning.
    if peak == 0:
        return np.zeros_like(waveform, dtype=np.float64)
    return waveform / peak
|
|
|
|
|
|
|
|
def waveform_to_bytes(waveform):
    """
    Serialize a waveform as raw little-endian float32 PCM bytes.

    The waveform is first peak-normalized via :func:`normalize_waveform`,
    then cast to 32-bit floats and dumped to its raw byte representation.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        bytes: The byte sequence representing the normalized waveform.
    """
    scaled = normalize_waveform(waveform)
    return scaled.astype(np.float32).tobytes()
|
|
|
|
|
|
|
|
def waveform_to_base64(waveform):
    """
    Convert a waveform array to a base64-encoded string.

    The waveform is serialized to raw float32 bytes via
    :func:`waveform_to_bytes` and then base64-encoded.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        str: The base64-encoded string representing the waveform.
    """
    # The original implementation round-tripped the bytes through a bare
    # `BytesIO()` name that was never imported (the file imports the `io`
    # module, not `from io import BytesIO`), raising NameError at runtime.
    # b64encode accepts the bytes directly, so the buffer is unnecessary.
    waveform_bytes = waveform_to_bytes(waveform)
    return base64.b64encode(waveform_bytes).decode('utf-8')
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint wrapper around a VITS text-to-speech model.

    Loads a VITS model and tokenizer from a local path or hub model ID and
    converts text input into base64-encoded WAV audio.
    """

    def __init__(self, path: str):
        """
        Initialize the endpoint with the model path.

        Args:
            path (str): The file path or model ID for loading the model.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = VitsTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Process a prediction request using the loaded model.

        Args:
            data (Dict[str, Any]): The request body. Must contain a
                non-empty 'inputs' key holding a string or list of strings.

        Returns:
            str: Base64-encoded WAV audio (note: the original annotation
            claimed List[Dict[str, Any]], but a plain string is returned).

        Raises:
            ValueError: If 'inputs' is missing or empty.
            TypeError: If any element of 'inputs' is not a string.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("The 'inputs' key is required in the data dictionary and cannot be empty.")

        # Accept a bare string by promoting it to a one-element batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        if not all(isinstance(i, str) for i in inputs):
            raise TypeError("All inputs must be strings.")

        return self.generate_predictions(inputs)

    def generate_predictions(self, texts: List[str]) -> str:
        """
        Generate speech audio for a list of texts.

        Args:
            texts (List[str]): Texts to synthesize.

        Returns:
            str: Base64-encoded WAV bytes at the model's sampling rate.

        Note:
            Only the first waveform of the batch (`output.numpy()[0]`) is
            encoded and returned, even when multiple texts are supplied —
            TODO(review): confirm this single-output contract is intended.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        with torch.no_grad():
            # `.waveform` holds the generated audio for the batch.
            output = self.model(**inputs).waveform

        # Serialize the first waveform as an in-memory WAV file.
        buffer = io.BytesIO()
        sf.write(buffer, output.numpy()[0], self.model.config.sampling_rate, format='WAV')
        buffer.seek(0)

        base64_audio = base64.b64encode(buffer.read()).decode('utf-8')
        return base64_audio