FarmerlineML committed on
Commit
171bfcf
·
verified ·
1 Parent(s): fdd7166

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +97 -0
handler.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import VitsModel, VitsTokenizer
3
+ import torch
4
+ import numpy as np
5
+ import base64
6
+ import soundfile as sf
7
+ import io
8
+
9
+
10
def normalize_waveform(waveform):
    """
    Scale a waveform so that its largest absolute sample equals 1.0.

    Args:
        waveform (np.ndarray): The waveform array to normalize.
    Returns:
        np.ndarray: The waveform divided by its peak absolute value. An empty
        or all-zero (silent) waveform is returned as an unscaled copy, because
        dividing by a zero peak would fill the result with NaN.
    """
    samples = np.asarray(waveform)

    # np.max raises on an empty array, so there is nothing to scale.
    if samples.size == 0:
        return samples.copy()

    peak = np.max(np.abs(samples))

    # Silence has no peak to normalize against; avoid 0 / 0 -> NaN.
    if peak == 0:
        return samples.copy()

    return samples / peak  # Normalize to -1 to 1 range
19
+
20
+
21
def waveform_to_bytes(waveform):
    """
    Serialize a waveform into raw float32 sample bytes.

    The waveform is first passed through ``normalize_waveform`` and then cast
    to ``np.float32`` before being dumped with ``ndarray.tobytes``.

    Args:
        waveform (np.ndarray): The waveform array.
    Returns:
        bytes: The byte sequence representing the normalized waveform.
    """
    scaled = normalize_waveform(waveform)
    as_float32 = scaled.astype(np.float32)
    return as_float32.tobytes()
32
+
33
+
34
def waveform_to_base64(waveform):
    """
    Convert the waveform array to a base64-encoded string.

    The waveform is serialized with ``waveform_to_bytes`` and the resulting
    bytes are base64-encoded.

    Args:
        waveform (np.ndarray): The waveform array.
    Returns:
        str: The base64-encoded string representing the waveform.
    """
    waveform_bytes = waveform_to_bytes(waveform)
    # Encode the bytes directly: the previous version routed them through an
    # undefined bare ``BytesIO`` name (only ``io`` is imported), which raised
    # NameError, and the intermediate stream added nothing.
    return base64.b64encode(waveform_bytes).decode('utf-8')
48
+
49
+
50
class EndpointHandler:
    """
    Request handler that synthesizes speech from text with a VITS model and
    returns the audio as a base64-encoded WAV string.
    """

    def __init__(self, path: str):
        """
        Initialize the endpoint with the model path.

        Args:
            path (str): The file path or model ID for loading the model.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = VitsTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Process a prediction request using the loaded model.

        Args:
            data (Dict[str, Any]): The request body. Its 'inputs' entry must be
                a non-empty string or a non-empty list of strings.
        Returns:
            str: Base64-encoded WAV audio produced by ``generate_predictions``.
        Raises:
            ValueError: If 'inputs' is missing or empty.
            TypeError: If 'inputs' is not a string or a list of strings.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("The 'inputs' key is required in the data dictionary and cannot be empty.")

        if isinstance(inputs, str):
            inputs = [inputs]  # Convert to list to handle consistently as batch
        elif not isinstance(inputs, (list, tuple)):
            # Reject dicts, numbers, etc. up front: a dict would otherwise pass
            # the per-item check via its keys, and a number would fail with an
            # unrelated "not iterable" error.
            raise TypeError("The 'inputs' value must be a string or a list of strings.")

        if not all(isinstance(item, str) for item in inputs):
            raise TypeError("All inputs must be strings.")

        return self.generate_predictions(list(inputs))

    def generate_predictions(self, texts: List[str]) -> str:
        """
        Synthesize audio for a list of texts and encode it as base64 WAV.

        NOTE: the whole list is tokenized and run through the model as one
        padded batch, but only the first waveform of the batch is written to
        the returned audio.

        Args:
            texts (List[str]): A list of texts for which to generate predictions.
        Returns:
            str: Base64 string of a WAV file containing the first waveform,
            written at the model config's sampling rate.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        with torch.no_grad():  # Inference only; no gradients needed
            waveforms = self.model(**inputs).waveform

        first_waveform = waveforms.numpy()[0]

        # Write the WAV container into memory instead of a file on disk.
        buffer = io.BytesIO()
        sf.write(buffer, first_waveform, self.model.config.sampling_rate, format='WAV')
        buffer.seek(0)  # Rewind the buffer to the beginning

        return base64.b64encode(buffer.read()).decode('utf-8')