hypaai commited on
Commit
61d617a
·
verified ·
1 Parent(s): 33694a9

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +382 -0
  2. requirements.txt +6 -0
handler.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import librosa
5
+ import soundfile as sf
6
+ import traceback
7
+ import base64
8
+ import io
9
+ import wave
10
+
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from snac import SNAC
13
+ from vllm import LLM, SamplingParams
14
+
15
+ class EndpointHandler:
16
    def __init__(self, path=""):
        """Initialize the TTS endpoint: load the Orpheus LLM via vLLM, its
        tokenizer, and the SNAC audio codec used to decode audio tokens.

        path: model directory / repo id passed to vLLM and the tokenizer.
        Raises RuntimeError if the SNAC codec cannot be loaded.
        """

        # Delimiter tokens as defined in Orpheus' vocabulary
        self.START_OF_HUMAN = 128259
        self.START_OF_TEXT = 128000
        self.END_OF_TEXT = 128009
        self.END_OF_HUMAN = 128260
        self.START_OF_AI = 128261
        self.START_OF_SPEECH = 128257
        self.END_OF_SPEECH = 128258
        self.END_OF_AI = 128262
        # First id of the audio-token region; each of the 7 per-frame slots
        # occupies its own 4096-wide range above this base (see tokenize_audio)
        self.AUDIO_TOKENS_START = 128266

        # Load the models and tokenizer
        # NOTE(review): gpu_memory_utilization=0.3 presumably leaves GPU room
        # for the SNAC codec loaded below — confirm sizing for the target GPU
        self.model = LLM(path, max_model_len = 4096, gpu_memory_utilization = 0.3)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Move to devices
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load SNAC model for audio decoding (24 kHz codec checkpoint)
        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Failed to load SNAC model: {e}")
42
+
43
+ # Set up functions to format and encode text/audio
44
+ def encode_text(self, text):
45
+ return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
46
+
47
+ def encode_audio(self, base64_audio_str):
48
+ audio_bytes = base64.b64decode(base64_audio_str)
49
+ audio_buffer = io.BytesIO(audio_bytes)
50
+ waveform, sr = sf.read(audio_buffer, dtype='float32')
51
+
52
+ if waveform.ndim > 1:
53
+ waveform = np.mean(waveform, axis=1)
54
+ if sr != 24000:
55
+ waveform = librosa.resample(waveform, orig_sr=sr, target_sr=24000)
56
+ return self.tokenize_audio(waveform)
57
+
58
+ def format_text_block(self, text_ids):
59
+ return [
60
+ torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64),
61
+ torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64),
62
+ text_ids,
63
+ torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64),
64
+ torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64)
65
+ ]
66
+
67
+ def format_audio_block(self, audio_codes):
68
+ return [
69
+ torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
70
+ torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
71
+ torch.tensor([audio_codes], dtype=torch.int64),
72
+ torch.tensor([[self.END_OF_SPEECH]], dtype=torch.int64),
73
+ torch.tensor([[self.END_OF_AI]], dtype=torch.int64)
74
+ ]
75
+
76
+ def enroll_user(self, enrollment_pairs):
77
+ """
78
+ Parameters:
79
+ - enrollment_pairs: List of tuples (text, audio_data), where audio_data is
80
+ base64-encoded audio data
81
+ Returns:
82
+ - cloning_features (str): serialized enrollment data
83
+ """
84
+ enrollment_data = []
85
+
86
+ for text, base64_audio in enrollment_pairs:
87
+ text_ids = self.encode_text(text).cpu()
88
+ audio_codes = self.encode_audio(base64_audio)
89
+ enrollment_data.append({
90
+ "text_ids": text_ids,
91
+ "audio_codes": audio_codes
92
+ })
93
+
94
+ # Serialize enrollment data
95
+ buffer = io.BytesIO()
96
+ torch.save(enrollment_data, buffer)
97
+ buffer.seek(0)
98
+
99
+ # Encode as base64 string and assign to attribute
100
+ cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
101
+ return cloning_features
102
+
103
+ def prepare_audio_tokens_for_decoder(self, audio_codes_list):
104
+ """
105
+ Given a list containing sequences of generated audio codes, do the following:
106
+ 1. Trim length to a multiple of 7 (SNAC decoder requires 7 tokens per audio frame)
107
+ 2. Adjust token values to SNAC decoder's expected range
108
+ """
109
+ modified_audio_codes_list = []
110
+ for audio_codes in audio_codes_list:
111
+
112
+ # Trim length to a multiple of 7
113
+ length = (audio_codes.size(0) // 7) * 7
114
+ trimmed = audio_codes[:length]
115
+
116
+ # Adjust token values to SNAC decoder's expected range
117
+ audio_codes = trimmed - self.AUDIO_TOKENS_START
118
+
119
+ # Add modified audio codes to list
120
+ modified_audio_codes_list.append(audio_codes)
121
+
122
+ return modified_audio_codes_list
123
+
124
    # Convert audio sample to codes and reconstruct
    def tokenize_audio(self, waveform):
        """Encode a mono waveform into a flat list of Orpheus audio token ids.

        SNAC's encoder returns three codebook layers at temporal rates 1:2:4
        (coarse to fine), as seen in the 1*i / 2*i / 4*i indexing below. The
        layers are interleaved into 7 tokens per frame, and each of the 7
        slots is offset by slot_index * 4096 above AUDIO_TOKENS_START so every
        slot occupies its own id range.

        # assumes `waveform` is a 1-D float32 numpy array at 24 kHz — TODO confirm with callers
        """
        waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            codes = self.snac_model.encode(waveform)

        all_codes = []
        for i in range(codes[0].shape[1]):

            # Per-frame slot layout: [L1, L2, L3, L3, L2, L3, L3]
            all_codes.append(codes[0][0][(1 * i) + 0].item() + self.AUDIO_TOKENS_START + (0 * 4096))
            all_codes.append(codes[1][0][(2 * i) + 0].item() + self.AUDIO_TOKENS_START + (1 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 0].item() + self.AUDIO_TOKENS_START + (2 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 1].item() + self.AUDIO_TOKENS_START + (3 * 4096))
            all_codes.append(codes[1][0][(2 * i) + 1].item() + self.AUDIO_TOKENS_START + (4 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 2].item() + self.AUDIO_TOKENS_START + (5 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 3].item() + self.AUDIO_TOKENS_START + (6 * 4096))

        return all_codes
143
+
144
+ def preprocess(self, data):
145
+
146
+ # Preprocess input data before inference
147
+
148
+ self.voice_cloning = data.get("clone", False)
149
+ clone_on_the_fly = data.get("clone_on_the_fly", False)
150
+
151
+ # Extract parameters from request
152
+ target_text = data["inputs"]
153
+ parameters = data.get("parameters", {})
154
+ cloning_features = data.get("cloning_features", None)
155
+
156
+ temperature = float(parameters.get("temperature", 0.6))
157
+ top_p = float(parameters.get("top_p", 0.95))
158
+ max_new_tokens = int(parameters.get("max_new_tokens", 1200))
159
+ repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
160
+
161
+ if self.voice_cloning:
162
+ if clone_on_the_fly:
163
+ # Clone using text-audio enrollment pair
164
+ enrollment_pairs = data.get("enrollments", [])
165
+ enrollment_data = []
166
+
167
+ # Raise error if no enrollment is provided
168
+ if not enrollment_pairs:
169
+ raise ValueError("No enrollment pairs provided")
170
+
171
+ for text, base64_audio in enrollment_pairs:
172
+ text_ids = self.encode_text(text).cpu()
173
+ audio_codes = self.encode_audio(base64_audio)
174
+ enrollment_data.append({
175
+ "text_ids": text_ids,
176
+ "audio_codes": audio_codes
177
+ })
178
+
179
+ elif not cloning_features:
180
+ raise ValueError("No cloning features were provided")
181
+ else:
182
+ # Clone using enrollment features gotten earlier
183
+ enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
184
+
185
+ # Process pre-tokenized enrollment_data
186
+ input_sequence = []
187
+ for item in enrollment_data:
188
+ text_ids = item["text_ids"]
189
+ audio_codes = item["audio_codes"]
190
+ input_sequence.extend(self.format_text_block(text_ids))
191
+ input_sequence.extend(self.format_audio_block(audio_codes))
192
+
193
+ # Append target text whose audio we want
194
+ target_text_ids = self.encode_text(target_text)
195
+ input_sequence.extend(self.format_text_block(target_text_ids))
196
+
197
+ # Start of target audio - audio codes to be completed by model
198
+ input_sequence.extend([
199
+ torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
200
+ torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64)
201
+ ])
202
+
203
+ # Final input tensor
204
+ input_ids = torch.cat(input_sequence, dim=1)
205
+
206
+ # Create attention mask and move tensors to device
207
+ attention_mask = torch.ones_like(input_ids)
208
+ input_ids = input_ids.to(self.device)
209
+ attention_mask = attention_mask.to(self.device)
210
+
211
+ else:
212
+ # Handle standard text-to-speech
213
+
214
+ # Extract parameters from request
215
+ voice = parameters.get("voice", "Eniola")
216
+ prompt = f"{voice}: {target_text}"
217
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
218
+
219
+ # Add special tokens
220
+ input_ids = torch.cat(self.format_text_block(input_ids), dim=1)
221
+
222
+ # No need for padding as we're processing a single sequence
223
+ input_ids = input_ids.to(self.device)
224
+
225
+ return {
226
+ "input_ids": input_ids,
227
+ "temperature": temperature,
228
+ "top_p": top_p,
229
+ "max_new_tokens": max_new_tokens,
230
+ "repetition_penalty": repetition_penalty,
231
+ }
232
+
233
+ def inference(self, inputs):
234
+ """
235
+ Run model inference on the preprocessed inputs
236
+ """
237
+ # Extract parameters
238
+ input_ids = inputs["input_ids"]
239
+
240
+ sampling_params = SamplingParams(
241
+ temperature = inputs["temperature"],
242
+ top_p = inputs["top_p"],
243
+ max_tokens = inputs["max_new_tokens"],
244
+ repetition_penalty = inputs["repetition_penalty"],
245
+ stop_token_ids = [self.END_OF_SPEECH],
246
+ )
247
+
248
+ prompt_string = self.tokenizer.decode(input_ids[0])
249
+
250
+ # Forward pass through the model
251
+ generated_ids = self.model.generate(prompt_string, sampling_params)
252
+
253
+ # return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
254
+ return {
255
+ "gen_ids": torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0),
256
+ "input_ids": input_ids
257
+ }
258
+
259
+ def __call__(self, data):
260
+
261
+ # Main entry point for the handler
262
+
263
+ try:
264
+ enroll_user = data.get("enroll_user", False)
265
+
266
+ if enroll_user:
267
+ # We extract cloning features for enrollment
268
+ enrollment_pairs = data.get("enrollments", [])
269
+ cloning_features = self.enroll_user(enrollment_pairs)
270
+ return {"cloning_features": cloning_features}
271
+ else:
272
+ # We want to generate speech using preset cloning features
273
+ preprocessed_inputs = self.preprocess(data)
274
+ model_outputs = self.inference(preprocessed_inputs)
275
+ response = self.postprocess(model_outputs)
276
+ return response
277
+
278
+ # Catch that error, baby
279
+ except Exception as e:
280
+ traceback.print_exc()
281
+ return {"error": str(e)}
282
+
283
+ # Postprocess generated ids
284
+ def convert_codes_to_waveform(self, code_list):
285
+ """
286
+ Reorganize tokens for SNAC decoding
287
+ """
288
+ layer_1 = [] # Coarsest layer
289
+ layer_2 = [] # Intermediate layer
290
+ layer_3 = [] # Finest layer
291
+
292
+ num_groups = len(code_list) // 7
293
+ for i in range(num_groups):
294
+ idx = 7 * i
295
+ layer_1.append(code_list[7 * i + 0] - (0 * 4096))
296
+ layer_2.append(code_list[7 * i + 1] - (1 * 4096))
297
+ layer_3.append(code_list[7 * i + 2] - (2 * 4096))
298
+ layer_3.append(code_list[7 * i + 3] - (3 * 4096))
299
+ layer_2.append(code_list[7 * i + 4] - (4 * 4096))
300
+ layer_3.append(code_list[7 * i + 5] - (5 * 4096))
301
+ layer_3.append(code_list[7 * i + 6] - (6 * 4096))
302
+
303
+ codes = [
304
+ torch.tensor(layer_1).unsqueeze(0).to(self.device),
305
+ torch.tensor(layer_2).unsqueeze(0).to(self.device),
306
+ torch.tensor(layer_3).unsqueeze(0).to(self.device)
307
+ ]
308
+
309
+ # Decode audio
310
+ audio_hat = self.snac_model.decode(codes)
311
+ return audio_hat
312
+
313
+ def postprocess(self, model_outputs):
314
+
315
+ generated_ids = model_outputs["gen_ids"]
316
+ input_ids = model_outputs["input_ids"]
317
+
318
+ if self.voice_cloning:
319
+ """
320
+ For cloning applications, use this postprocess function to get generated audio samples
321
+ """
322
+ # Modify audio codes to be digestible byb SNAC decoder
323
+ code_lists = self.prepare_audio_tokens_for_decoder(generated_ids)
324
+
325
+ # Generate audio from codes
326
+ temp = self.convert_codes_to_waveform(code_lists[0])
327
+ audio_sample = temp.detach().squeeze().to("cpu").numpy()
328
+
329
+ else:
330
+ """
331
+ Process generated tokens into audio
332
+ """
333
+ # Find Start of Audio token
334
+ token_indices = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True)
335
+
336
+ if len(token_indices[1]) > 0:
337
+ last_occurrence_idx = token_indices[1][-1].item()
338
+ cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
339
+ else:
340
+ cropped_tensor = generated_ids
341
+
342
+ # Remove End of Audio tokens
343
+ processed_rows = []
344
+ for row in cropped_tensor:
345
+ masked_row = row[row != self.END_OF_SPEECH]
346
+ processed_rows.append(masked_row)
347
+
348
+ code_lists = self.prepare_audio_tokens_for_decoder(processed_rows)
349
+
350
+ # Generate audio from codes
351
+ audio_samples = []
352
+ for code_list in code_lists:
353
+ if len(code_list) > 0:
354
+ audio = self.convert_codes_to_waveform(code_list)
355
+ audio_samples.append(audio)
356
+ else:
357
+ raise ValueError("Empty code list, no audio to generate")
358
+
359
+ if not audio_samples:
360
+ return {"error": "No audio samples generated"}
361
+
362
+ # Return first (and only) audio sample
363
+ audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
364
+
365
+ # Convert float32 array to int16 for WAV format
366
+ audio_int16 = (audio_sample * 32767).astype(np.int16)
367
+
368
+ # Write to WAV in memory (float32 or int16 depending on your preference)
369
+ buffer = io.BytesIO()
370
+ sf.write(buffer, audio_sample, samplerate=24000, format='WAV', subtype='PCM_16') # or PCM_32
371
+ buffer.seek(0)
372
+
373
+ # Encode WAV bytes as base64
374
+ audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
375
+
376
+ return {
377
+ "audio_sample": audio_sample,
378
+ "audio_b64": audio_b64,
379
+ "sample_rate": 24000,
380
+ "input_ids_len": input_ids.shape[1],
381
+ "gen_ids_len": generated_ids.shape[1]
382
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
torch>=2.0.0
transformers>=4.30.0
snac>=0.1.0
numpy>=1.20.0
protobuf
vllm
librosa
soundfile