from typing import Dict, List, Any
from transformers import VitsModel, VitsTokenizer
import torch
import numpy as np
import base64
import soundfile as sf
import io


def normalize_waveform(waveform: np.ndarray) -> np.ndarray:
    """Normalize a waveform to the [-1, 1] range suitable for audio playback.

    Args:
        waveform (np.ndarray): The waveform array to normalize.

    Returns:
        np.ndarray: The normalized waveform array. An all-zero (silent)
        waveform is returned unchanged to avoid a division by zero.
    """
    peak = np.max(np.abs(waveform))
    # Guard against silence: dividing by a zero peak would yield NaN/inf.
    if peak == 0:
        return waveform
    return waveform / peak


def waveform_to_bytes(waveform: np.ndarray) -> bytes:
    """Convert a waveform array to a raw little-endian float32 byte sequence.

    The waveform is peak-normalized before conversion.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        bytes: The byte sequence representing the normalized waveform.
    """
    waveform_normalized = normalize_waveform(waveform)
    return waveform_normalized.astype(np.float32).tobytes()


def waveform_to_base64(waveform: np.ndarray) -> str:
    """Convert a waveform array to a base64-encoded string.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        str: The base64-encoded string representing the waveform bytes.
    """
    # Fixed: the original referenced a bare `BytesIO` name that was never
    # imported (only `import io` is in scope), raising NameError at runtime.
    # The intermediate stream was unnecessary anyway — encode the bytes directly.
    return base64.b64encode(waveform_to_bytes(waveform)).decode('utf-8')


class EndpointHandler:
    """Inference endpoint wrapper around a Hugging Face VITS text-to-speech model."""

    def __init__(self, path: str):
        """Initialize the endpoint with the model path.

        Args:
            path (str): The file path or model ID for loading the model.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = VitsTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Process a prediction request using the loaded model.

        Args:
            data (Dict[str, Any]): The request body; must contain a non-empty
                'inputs' key holding a string or a list of strings.

        Returns:
            str: Base64-encoded WAV audio for the (first) input text.

        Raises:
            ValueError: If 'inputs' is missing or empty.
            TypeError: If any input is not a string.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("The 'inputs' key is required in the data dictionary and cannot be empty.")
        if isinstance(inputs, str):
            inputs = [inputs]  # Convert to list to handle consistently as batch
        if not all(isinstance(i, str) for i in inputs):
            raise TypeError("All inputs must be strings.")
        return self.generate_predictions(inputs)

    def generate_predictions(self, texts: List[str]) -> str:
        """Generate speech audio for a list of texts.

        NOTE(review): only the first waveform of the batch is encoded —
        additional texts are tokenized but their audio is discarded; confirm
        this is intended before sending batches.

        Args:
            texts (List[str]): A list of texts to synthesize.

        Returns:
            str: Base64-encoded WAV audio for the first text.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        with torch.no_grad():
            output = self.model(**inputs).waveform
        buffer = io.BytesIO()
        sf.write(buffer, output.numpy()[0], self.model.config.sampling_rate, format='WAV')
        buffer.seek(0)  # Rewind the buffer before reading it back out
        return base64.b64encode(buffer.read()).decode('utf-8')