File size: 3,379 Bytes
171bfcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Dict, List, Any
from transformers import VitsModel, VitsTokenizer
import torch
import numpy as np
import base64
import soundfile as sf
import io


def normalize_waveform(waveform):
    """
    Scales a waveform so that its largest absolute sample becomes 1.0.

    Silent (all-zero) and empty inputs have no peak to scale by; they are
    returned unscaled instead of producing NaNs (0 / 0) or raising on an
    empty reduction.

    Args:
        waveform (np.ndarray): The waveform array to normalize.
    Returns:
        np.ndarray: The normalized waveform array, with values in [-1, 1].
    """
    samples = np.asarray(waveform)
    if samples.size == 0:
        # np.max raises on an empty array; dividing by 1.0 keeps the same
        # result dtype that the regular division path would produce.
        return samples / 1.0

    peak = np.max(np.abs(samples))
    if peak == 0:
        # Pure silence: dividing by the peak would turn every sample into NaN.
        return samples / 1.0

    return samples / peak  # Normalize to -1 to 1 range


def waveform_to_bytes(waveform):
    """
    Serializes a waveform into its raw float32 byte representation.

    The waveform is first passed through normalize_waveform, then cast to
    float32 and dumped with ndarray.tobytes().

    Args:
        waveform (np.ndarray): The waveform array.
    Returns:
        bytes: The byte sequence representing the normalized waveform.
    """
    as_float32 = np.float32
    return normalize_waveform(waveform).astype(as_float32).tobytes()


def waveform_to_base64(waveform):
    """
    Converts the waveform array to a base64-encoded string.

    The waveform is serialized with waveform_to_bytes and the resulting
    bytes are base64-encoded directly; no intermediate stream is needed.

    Args:
        waveform (np.ndarray): The waveform array.
    Returns:
        str: The base64-encoded string representing the waveform.
    """
    waveform_bytes = waveform_to_bytes(waveform)
    # Encode the bytes directly. The previous implementation routed them
    # through a `BytesIO` name that is not imported in this module (only
    # `io` is), so every call failed with NameError.
    return base64.b64encode(waveform_bytes).decode('utf-8')


class EndpointHandler:
    """
    Request handler that synthesizes speech from text with a VITS model.

    The model and tokenizer are loaded once at construction time. Each call
    returns WAV audio encoded as a base64 string.
    """

    def __init__(self, path: str):
        """
        Initialize the endpoint with the model path.
        Args:
            path (str): The file path or model ID for loading the model.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = VitsTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Process a prediction request using the loaded model.
        Args:
            data (Dict[str, Any]): The request body. Its 'inputs' entry must be
                a non-empty string or a non-empty collection of strings.
        Returns:
            str: Base64-encoded WAV audio, as produced by generate_predictions.
        Raises:
            ValueError: If 'inputs' is missing or empty.
            TypeError: If any input is not a string.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("The 'inputs' key is required in the data dictionary and cannot be empty.")

        if isinstance(inputs, str):
            inputs = [inputs]  # Convert to list to handle consistently as batch

        if not all(isinstance(i, str) for i in inputs):
            raise TypeError("All inputs must be strings.")

        return self.generate_predictions(inputs)

    def generate_predictions(self, texts: List[str]) -> str:
        """
        Synthesize speech for a list of texts and return it as base64 WAV.

        NOTE: the whole list is tokenized and run through the model as one
        padded batch, but only the first waveform of that batch is encoded
        and returned. Audio for any further texts is discarded.

        Args:
            texts (List[str]): The texts to synthesize.
        Returns:
            str: Base64-encoded WAV file containing the first waveform.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        with torch.no_grad():
            output = self.model(**inputs).waveform

        # detach()/cpu() are no-ops for a CPU tensor without grad, but keep
        # the numpy conversion valid if the model is ever moved to another device.
        first_waveform = output.detach().cpu().numpy()[0]

        with io.BytesIO() as buffer:
            sf.write(buffer, first_waveform, self.model.config.sampling_rate, format='WAV')
            # getvalue() returns the full buffer regardless of the stream position.
            wav_bytes = buffer.getvalue()

        return base64.b64encode(wav_bytes).decode('utf-8')