Update handler.py
Browse files
Updating the handler so that it supports on-the-fly cloning and returns both the input IDs and the output IDs.
- handler.py +35 -12
handler.py
CHANGED
|
@@ -32,7 +32,6 @@ class EndpointHandler:
|
|
| 32 |
|
| 33 |
# Move to devices
|
| 34 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
-
# self.model.to(self.device)
|
| 36 |
|
| 37 |
# Load SNAC model for audio decoding
|
| 38 |
try:
|
|
@@ -147,6 +146,7 @@ class EndpointHandler:
|
|
| 147 |
# Preprocess input data before inference
|
| 148 |
|
| 149 |
self.voice_cloning = data.get("clone", False)
|
|
|
|
| 150 |
|
| 151 |
# Extract parameters from request
|
| 152 |
target_text = data["inputs"]
|
|
@@ -159,12 +159,27 @@ class EndpointHandler:
|
|
| 159 |
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
|
| 160 |
|
| 161 |
if self.voice_cloning:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
raise ValueError("No cloning features were provided")
|
| 166 |
else:
|
| 167 |
-
#
|
| 168 |
enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
|
| 169 |
|
| 170 |
# Process pre-tokenized enrollment_data
|
|
@@ -187,13 +202,11 @@ class EndpointHandler:
|
|
| 187 |
|
| 188 |
# Final input tensor
|
| 189 |
input_ids = torch.cat(input_sequence, dim=1)
|
| 190 |
-
|
| 191 |
-
# Heuristic to determine max_new_tokens based on empirical relationship
|
| 192 |
-
# between the length of the prompt ids and the length of the generated ids
|
| 193 |
-
prompt_ids = self.encode_text(target_text)
|
| 194 |
-
max_new_tokens = int(prompt_ids.size()[1] * 20 + 200)
|
| 195 |
|
|
|
|
|
|
|
| 196 |
input_ids = input_ids.to(self.device)
|
|
|
|
| 197 |
|
| 198 |
else:
|
| 199 |
# Handle standard text-to-speech
|
|
@@ -237,7 +250,11 @@ class EndpointHandler:
|
|
| 237 |
# Forward pass through the model
|
| 238 |
generated_ids = self.model.generate(prompt_string, sampling_params)
|
| 239 |
|
| 240 |
-
return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
def __call__(self, data):
|
| 243 |
|
|
@@ -293,7 +310,10 @@ class EndpointHandler:
|
|
| 293 |
audio_hat = self.snac_model.decode(codes)
|
| 294 |
return audio_hat
|
| 295 |
|
| 296 |
-
def postprocess(self,
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
if self.voice_cloning:
|
| 299 |
"""
|
|
@@ -357,4 +377,7 @@ class EndpointHandler:
|
|
| 357 |
"audio_sample": audio_sample,
|
| 358 |
"audio_b64": audio_b64,
|
| 359 |
"sample_rate": 24000,
|
|
|
|
|
|
|
|
|
|
| 360 |
}
|
|
|
|
| 32 |
|
| 33 |
# Move to devices
|
| 34 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 35 |
|
| 36 |
# Load SNAC model for audio decoding
|
| 37 |
try:
|
|
|
|
| 146 |
# Preprocess input data before inference
|
| 147 |
|
| 148 |
self.voice_cloning = data.get("clone", False)
|
| 149 |
+
clone_on_the_fly = data.get("clone_on_the_fly", False)
|
| 150 |
|
| 151 |
# Extract parameters from request
|
| 152 |
target_text = data["inputs"]
|
|
|
|
| 159 |
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
|
| 160 |
|
| 161 |
if self.voice_cloning:
|
| 162 |
+
if clone_on_the_fly:
|
| 163 |
+
# Clone using text-audio enrollment pair
|
| 164 |
+
enrollment_pairs = data.get("enrollments", [])
|
| 165 |
+
enrollment_data = []
|
| 166 |
+
|
| 167 |
+
# Raise error if no enrollment is provided
|
| 168 |
+
if not enrollment_pairs:
|
| 169 |
+
raise ValueError("No enrollment pairs provided")
|
| 170 |
+
|
| 171 |
+
for text, base64_audio in enrollment_pairs:
|
| 172 |
+
text_ids = self.encode_text(text).cpu()
|
| 173 |
+
audio_codes = self.encode_audio(base64_audio)
|
| 174 |
+
enrollment_data.append({
|
| 175 |
+
"text_ids": text_ids,
|
| 176 |
+
"audio_codes": audio_codes
|
| 177 |
+
})
|
| 178 |
+
|
| 179 |
+
elif not cloning_features:
|
| 180 |
raise ValueError("No cloning features were provided")
|
| 181 |
else:
|
| 182 |
+
# Clone using enrollment features gotten earlier
|
| 183 |
enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
|
| 184 |
|
| 185 |
# Process pre-tokenized enrollment_data
|
|
|
|
| 202 |
|
| 203 |
# Final input tensor
|
| 204 |
input_ids = torch.cat(input_sequence, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
+
# Create attention mask and move tensors to device
|
| 207 |
+
attention_mask = torch.ones_like(input_ids)
|
| 208 |
input_ids = input_ids.to(self.device)
|
| 209 |
+
attention_mask = attention_mask.to(self.device)
|
| 210 |
|
| 211 |
else:
|
| 212 |
# Handle standard text-to-speech
|
|
|
|
| 250 |
# Forward pass through the model
|
| 251 |
generated_ids = self.model.generate(prompt_string, sampling_params)
|
| 252 |
|
| 253 |
+
# return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
|
| 254 |
+
return {
|
| 255 |
+
"gen_ids": torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0),
|
| 256 |
+
"input_ids": input_ids
|
| 257 |
+
}
|
| 258 |
|
| 259 |
def __call__(self, data):
|
| 260 |
|
|
|
|
| 310 |
audio_hat = self.snac_model.decode(codes)
|
| 311 |
return audio_hat
|
| 312 |
|
| 313 |
+
def postprocess(self, model_outputs):
|
| 314 |
+
|
| 315 |
+
generated_ids = model_outputs["gen_ids"]
|
| 316 |
+
input_ids = model_outputs["input_ids"]
|
| 317 |
|
| 318 |
if self.voice_cloning:
|
| 319 |
"""
|
|
|
|
| 377 |
"audio_sample": audio_sample,
|
| 378 |
"audio_b64": audio_b64,
|
| 379 |
"sample_rate": 24000,
|
| 380 |
+
"gen_ids": generated_ids,
|
| 381 |
+
"input_ids_len": input_ids.shape[1],
|
| 382 |
+
"gen_ids_len": generated_ids.shape[1]
|
| 383 |
}
|