# Krio-TTS / handler.py
# Author: FarmerlineML
# Commit: "Create handler.py" (171bfcf, verified)
from typing import Dict, List, Any
from transformers import VitsModel, VitsTokenizer
import torch
import numpy as np
import base64
import soundfile as sf
import io
def normalize_waveform(waveform):
    """
    Normalize a waveform to the [-1, 1] range for audio playback.

    Args:
        waveform (np.ndarray): The waveform array to normalize.

    Returns:
        np.ndarray: The peak-normalized waveform. A silent (all-zero) or
        empty waveform is returned unchanged, since dividing by a zero
        peak would produce NaN/inf values.
    """
    peak = np.max(np.abs(waveform)) if waveform.size else 0.0
    # Guard against division by zero for silence / empty input.
    if peak == 0:
        return waveform
    return waveform / peak
def waveform_to_bytes(waveform):
    """
    Serialize a waveform array as raw little-endian float32 bytes.

    The samples are first peak-normalized to the [-1, 1] range via
    ``normalize_waveform`` and then converted to 32-bit floats.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        bytes: The byte sequence representing the normalized waveform.
    """
    # Normalize first so the serialized samples sit in a playable range.
    scaled = normalize_waveform(waveform)
    return scaled.astype(np.float32).tobytes()
def waveform_to_base64(waveform):
    """
    Convert a waveform array to a base64-encoded string.

    Args:
        waveform (np.ndarray): The waveform array.

    Returns:
        str: The base64-encoded string of the normalized float32 bytes.
    """
    waveform_bytes = waveform_to_bytes(waveform)
    # Bug fix: the original wrote the bytes through a bare `BytesIO`,
    # which was never imported (only the `io` module is in scope) and
    # raised NameError at runtime. Encoding the bytes directly is both
    # correct and simpler.
    return base64.b64encode(waveform_bytes).decode('utf-8')
class EndpointHandler:
    """Inference-endpoint wrapper around a VITS text-to-speech model."""

    def __init__(self, path: str):
        """
        Initialize the endpoint with the model path.

        Args:
            path (str): The file path or model ID for loading the model.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = VitsTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Process a prediction request using the loaded model.

        Args:
            data (Dict[str, Any]): The request body. Must contain 'inputs',
                either a single string or a list of strings.

        Returns:
            str: Base64-encoded WAV audio for the first input text.
                 (The original annotation claimed ``List[Dict[str, Any]]``,
                 but a base64 string has always been what is returned.)

        Raises:
            ValueError: If 'inputs' is missing or empty.
            TypeError: If any element of 'inputs' is not a string.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("The 'inputs' key is required in the data dictionary and cannot be empty.")
        if isinstance(inputs, str):
            inputs = [inputs]  # Normalize to a list so batches are handled uniformly.
        if not all(isinstance(i, str) for i in inputs):
            raise TypeError("All inputs must be strings.")
        return self.generate_predictions(inputs)

    def generate_predictions(self, texts: List[str]) -> str:
        """
        Synthesize speech for the given texts and return base64 WAV audio.

        NOTE: although the whole batch is tokenized and run through the
        model, only the waveform of the FIRST text is encoded and returned
        (original behavior preserved).

        Args:
            texts (List[str]): Texts to synthesize.

        Returns:
            str: Base64-encoded WAV file of the first generated waveform.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        with torch.no_grad():  # Inference only — no gradients needed.
            output = self.model(**inputs).waveform
        buffer = io.BytesIO()
        # Serialize the first waveform as WAV at the model's sampling rate.
        sf.write(buffer, output.numpy()[0], self.model.config.sampling_rate, format='WAV')
        buffer.seek(0)  # Rewind the buffer before reading the encoded bytes.
        return base64.b64encode(buffer.read()).decode('utf-8')