hypaai committed on
Commit
bfcc307
·
verified ·
1 Parent(s): 6d8831f

Update handler.py

Browse files

Updating the handler so it supports cloning on the fly and returns the input IDs & output IDs.

Files changed (1) hide show
  1. handler.py +35 -12
handler.py CHANGED
@@ -32,7 +32,6 @@ class EndpointHandler:
32
 
33
  # Move to devices
34
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
35
- # self.model.to(self.device)
36
 
37
  # Load SNAC model for audio decoding
38
  try:
@@ -147,6 +146,7 @@ class EndpointHandler:
147
  # Preprocess input data before inference
148
 
149
  self.voice_cloning = data.get("clone", False)
 
150
 
151
  # Extract parameters from request
152
  target_text = data["inputs"]
@@ -159,12 +159,27 @@ class EndpointHandler:
159
  repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
160
 
161
  if self.voice_cloning:
162
- """Handle voice cloning using cloning features"""
163
-
164
- if not cloning_features:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  raise ValueError("No cloning features were provided")
166
  else:
167
- # Decode back into tensors
168
  enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
169
 
170
  # Process pre-tokenized enrollment_data
@@ -187,13 +202,11 @@ class EndpointHandler:
187
 
188
  # Final input tensor
189
  input_ids = torch.cat(input_sequence, dim=1)
190
-
191
- # Heuristic to determine max_new_tokens based on empirical relationship
192
- # between the length of the prompt ids and the length of the generated ids
193
- prompt_ids = self.encode_text(target_text)
194
- max_new_tokens = int(prompt_ids.size()[1] * 20 + 200)
195
 
 
 
196
  input_ids = input_ids.to(self.device)
 
197
 
198
  else:
199
  # Handle standard text-to-speech
@@ -237,7 +250,11 @@ class EndpointHandler:
237
  # Forward pass through the model
238
  generated_ids = self.model.generate(prompt_string, sampling_params)
239
 
240
- return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
 
 
 
 
241
 
242
  def __call__(self, data):
243
 
@@ -293,7 +310,10 @@ class EndpointHandler:
293
  audio_hat = self.snac_model.decode(codes)
294
  return audio_hat
295
 
296
- def postprocess(self, generated_ids):
 
 
 
297
 
298
  if self.voice_cloning:
299
  """
@@ -357,4 +377,7 @@ class EndpointHandler:
357
  "audio_sample": audio_sample,
358
  "audio_b64": audio_b64,
359
  "sample_rate": 24000,
 
 
 
360
  }
 
32
 
33
  # Move to devices
34
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
35
 
36
  # Load SNAC model for audio decoding
37
  try:
 
146
  # Preprocess input data before inference
147
 
148
  self.voice_cloning = data.get("clone", False)
149
+ clone_on_the_fly = data.get("clone_on_the_fly", False)
150
 
151
  # Extract parameters from request
152
  target_text = data["inputs"]
 
159
  repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
160
 
161
  if self.voice_cloning:
162
+ if clone_on_the_fly:
163
+ # Clone using text-audio enrollment pair
164
+ enrollment_pairs = data.get("enrollments", [])
165
+ enrollment_data = []
166
+
167
+ # Raise error if no enrollment is provided
168
+ if not enrollment_pairs:
169
+ raise ValueError("No enrollment pairs provided")
170
+
171
+ for text, base64_audio in enrollment_pairs:
172
+ text_ids = self.encode_text(text).cpu()
173
+ audio_codes = self.encode_audio(base64_audio)
174
+ enrollment_data.append({
175
+ "text_ids": text_ids,
176
+ "audio_codes": audio_codes
177
+ })
178
+
179
+ elif not cloning_features:
180
  raise ValueError("No cloning features were provided")
181
  else:
182
+ # Clone using enrollment features gotten earlier
183
  enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
184
 
185
  # Process pre-tokenized enrollment_data
 
202
 
203
  # Final input tensor
204
  input_ids = torch.cat(input_sequence, dim=1)
 
 
 
 
 
205
 
206
+ # Create attention mask and move tensors to device
207
+ attention_mask = torch.ones_like(input_ids)
208
  input_ids = input_ids.to(self.device)
209
+ attention_mask = attention_mask.to(self.device)
210
 
211
  else:
212
  # Handle standard text-to-speech
 
250
  # Forward pass through the model
251
  generated_ids = self.model.generate(prompt_string, sampling_params)
252
 
253
+ # return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
254
+ return {
255
+ "gen_ids": torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0),
256
+ "input_ids": input_ids
257
+ }
258
 
259
  def __call__(self, data):
260
 
 
310
  audio_hat = self.snac_model.decode(codes)
311
  return audio_hat
312
 
313
+ def postprocess(self, model_outputs):
314
+
315
+ generated_ids = model_outputs["gen_ids"]
316
+ input_ids = model_outputs["input_ids"]
317
 
318
  if self.voice_cloning:
319
  """
 
377
  "audio_sample": audio_sample,
378
  "audio_b64": audio_b64,
379
  "sample_rate": 24000,
380
+ "gen_ids": generated_ids,
381
+ "input_ids_len": input_ids.shape[1],
382
+ "gen_ids_len": generated_ids.shape[1]
383
  }