Updated handler.py with claude
Browse files- handler.py +186 -193
handler.py
CHANGED
|
@@ -1,20 +1,176 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
-
import soundfile as sf
|
| 4 |
-
import io
|
| 5 |
-
import os
|
| 6 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
-
from snac import SNAC
|
| 8 |
-
|
| 9 |
-
# --- Helper Function (can be outside or inside the class) ---
|
| 10 |
-
def redistribute_codes_static(code_list):
|
| 11 |
-
""" Reorganizes the flattened token list into three separate layers for SNAC. """
|
| 12 |
-
layer_1, layer_2, layer_3 = [], [], []
|
| 13 |
-
num_groups = len(code_list) // 7 # Use floor division
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
layer_1.append(code_list[idx])
|
| 19 |
layer_2.append(code_list[idx + 1] - 4096)
|
| 20 |
layer_3.append(code_list[idx + 2] - (2 * 4096))
|
|
@@ -22,185 +178,22 @@ def redistribute_codes_static(code_list):
|
|
| 22 |
layer_2.append(code_list[idx + 4] - (4 * 4096))
|
| 23 |
layer_3.append(code_list[idx + 5] - (5 * 4096))
|
| 24 |
layer_3.append(code_list[idx + 6] - (6 * 4096))
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
if layer_2: codes[1] = torch.tensor(layer_2).unsqueeze(0).long()
|
| 38 |
-
if layer_3: codes[2] = torch.tensor(layer_3).unsqueeze(0).long()
|
| 39 |
-
|
| 40 |
-
return codes
|
| 41 |
-
|
| 42 |
-
# --- Endpoint Handler Class ---
|
| 43 |
-
class EndpointHandler():
|
| 44 |
-
def __init__(self, path=""):
|
| 45 |
-
"""
|
| 46 |
-
Initializes the handler. Loads both Orpheus LLM and SNAC Vocoder.
|
| 47 |
-
'path' points to the directory containing the Orpheus model files specified in the endpoint config.
|
| 48 |
-
"""
|
| 49 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
-
print(f"Using device: {self.device}")
|
| 51 |
-
|
| 52 |
-
# Define Model Names/Paths
|
| 53 |
-
# Orpheus LLM path is determined by the endpoint configuration ('path' variable)
|
| 54 |
-
orpheus_model_path = path if path else "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
|
| 55 |
-
snac_model_name = "hubertsiuzdak/snac_24khz"
|
| 56 |
-
|
| 57 |
-
# Define Special Token IDs (matching your script)
|
| 58 |
-
self.start_human_token_id = 128259
|
| 59 |
-
self.end_text_token_id = 128009
|
| 60 |
-
self.end_human_token_id = 128260
|
| 61 |
-
# self.padding_token_id = 128263 # Not needed for single sequence generation
|
| 62 |
-
self.start_audio_token_id = 128257
|
| 63 |
-
self.end_audio_token_id = 128258
|
| 64 |
-
self.audio_code_offset = 128266
|
| 65 |
-
|
| 66 |
-
# Define sampling rate
|
| 67 |
-
self.sampling_rate = 24000
|
| 68 |
-
|
| 69 |
-
try:
|
| 70 |
-
# Load Orpheus LLM and Tokenizer
|
| 71 |
-
print(f"Loading Orpheus tokenizer from: {orpheus_model_path}")
|
| 72 |
-
self.tokenizer = AutoTokenizer.from_pretrained(orpheus_model_path)
|
| 73 |
-
print(f"Loading Orpheus model from: {orpheus_model_path}")
|
| 74 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
| 75 |
-
orpheus_model_path,
|
| 76 |
-
torch_dtype=torch.bfloat16 # Use bfloat16 as in your script
|
| 77 |
-
)
|
| 78 |
-
self.model.to(self.device)
|
| 79 |
-
self.model.eval() # Set model to evaluation mode
|
| 80 |
-
print("Orpheus model and tokenizer loaded successfully.")
|
| 81 |
-
|
| 82 |
-
# Load SNAC Vocoder
|
| 83 |
-
print(f"Loading SNAC model from: {snac_model_name}")
|
| 84 |
-
self.snac_model = SNAC.from_pretrained(snac_model_name)
|
| 85 |
-
self.snac_model.to(self.device) # Move SNAC to the same device
|
| 86 |
-
self.snac_model.eval() # Set model to evaluation mode
|
| 87 |
-
print("SNAC model loaded successfully.")
|
| 88 |
-
|
| 89 |
-
except Exception as e:
|
| 90 |
-
print(f"Error during model loading: {e}")
|
| 91 |
-
raise RuntimeError(f"Failed to load models.", e)
|
| 92 |
-
|
| 93 |
-
def __call__(self, data: dict) -> bytes:
|
| 94 |
"""
|
| 95 |
-
|
| 96 |
-
Expects data['inputs'] (text) and optionally data['parameters']
|
| 97 |
"""
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
raise ValueError("Missing 'inputs' key in request data")
|
| 103 |
-
|
| 104 |
-
parameters = data.pop("parameters", {})
|
| 105 |
-
# Default voice if not provided
|
| 106 |
-
voice = parameters.get("voice", "Eniola")
|
| 107 |
-
# Default generation parameters (merge with provided ones)
|
| 108 |
-
gen_params = {
|
| 109 |
-
"max_new_tokens": 1200,
|
| 110 |
-
"do_sample": True,
|
| 111 |
-
"temperature": 0.6,
|
| 112 |
-
"top_p": 0.95,
|
| 113 |
-
"repetition_penalty": 1.1,
|
| 114 |
-
"num_return_sequences": 1,
|
| 115 |
-
"eos_token_id": self.end_audio_token_id,
|
| 116 |
-
**parameters # Overwrite defaults with user params
|
| 117 |
-
}
|
| 118 |
-
# Remove non-generate params if they were passed
|
| 119 |
-
gen_params.pop("voice", None)
|
| 120 |
-
|
| 121 |
-
print(f"Received request: text='{text[:50]}...', voice='{voice}', params={gen_params}")
|
| 122 |
-
|
| 123 |
-
# --- Preprocess Text ---
|
| 124 |
-
prompt = f"{voice}: {text}"
|
| 125 |
-
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
| 126 |
-
|
| 127 |
-
# Add special tokens: SOH + Input Tokens + EOT + EOH
|
| 128 |
-
start_token_tensor = torch.tensor([[self.start_human_token_id]], dtype=torch.int64)
|
| 129 |
-
end_tokens_tensor = torch.tensor([[self.end_text_token_id, self.end_human_token_id]], dtype=torch.int64)
|
| 130 |
-
processed_input_ids = torch.cat([start_token_tensor, input_ids, end_tokens_tensor], dim=1).to(self.device)
|
| 131 |
-
|
| 132 |
-
# Create attention mask (all ones for single, unpadded sequence)
|
| 133 |
-
attention_mask = torch.ones_like(processed_input_ids).to(self.device)
|
| 134 |
-
|
| 135 |
-
print(f"Processed input shape: {processed_input_ids.shape}")
|
| 136 |
-
|
| 137 |
-
# --- Generate Audio Codes (LLM Inference) ---
|
| 138 |
-
with torch.no_grad():
|
| 139 |
-
generated_ids = self.model.generate(
|
| 140 |
-
input_ids=processed_input_ids,
|
| 141 |
-
attention_mask=attention_mask,
|
| 142 |
-
**gen_params
|
| 143 |
-
)
|
| 144 |
-
print(f"Generated IDs shape: {generated_ids.shape}")
|
| 145 |
-
|
| 146 |
-
# --- Process Generated Tokens (Extract Audio Codes) ---
|
| 147 |
-
# Find the last Start of Audio token
|
| 148 |
-
soa_indices = (generated_ids[0] == self.start_audio_token_id).nonzero(as_tuple=True)[0]
|
| 149 |
-
if len(soa_indices) == 0:
|
| 150 |
-
print("Warning: Start of Audio token (128257) not found in generated sequence!")
|
| 151 |
-
# Handle this case: maybe return error, or try processing from start?
|
| 152 |
-
# For now, let's assume it might still contain codes and try processing all generated *new* tokens
|
| 153 |
-
start_idx = processed_input_ids.shape[1] # Start after the input prompt
|
| 154 |
-
else:
|
| 155 |
-
start_idx = soa_indices[-1].item() + 1 # Start after the last SOA token
|
| 156 |
-
|
| 157 |
-
# Extract potential audio codes (after last SOA or after input)
|
| 158 |
-
cropped_tokens = generated_ids[0, start_idx:]
|
| 159 |
-
|
| 160 |
-
# Remove End of Audio tokens
|
| 161 |
-
audio_codes_raw = cropped_tokens[cropped_tokens != self.end_audio_token_id]
|
| 162 |
-
print(f"Extracted raw audio codes count: {len(audio_codes_raw)}")
|
| 163 |
-
|
| 164 |
-
if len(audio_codes_raw) == 0:
|
| 165 |
-
raise ValueError("No audio codes generated or extracted after processing.")
|
| 166 |
-
|
| 167 |
-
# --- Prepare Codes for SNAC Vocoder ---
|
| 168 |
-
# Adjust token values
|
| 169 |
-
adjusted_codes = [t.item() - self.audio_code_offset for t in audio_codes_raw]
|
| 170 |
-
|
| 171 |
-
# Trim to multiple of 7
|
| 172 |
-
num_codes = len(adjusted_codes)
|
| 173 |
-
valid_length = (num_codes // 7) * 7
|
| 174 |
-
if valid_length == 0:
|
| 175 |
-
raise ValueError(f"Not enough audio codes ({num_codes}) to form a multiple of 7 after processing.")
|
| 176 |
-
trimmed_codes = adjusted_codes[:valid_length]
|
| 177 |
-
print(f"Trimmed adjusted audio codes count: {len(trimmed_codes)}")
|
| 178 |
-
|
| 179 |
-
# --- Redistribute Codes ---
|
| 180 |
-
# Use static method or instance method, ensure tensors are on correct device
|
| 181 |
-
snac_input_codes = redistribute_codes_static(trimmed_codes)
|
| 182 |
-
snac_input_codes = [layer.to(self.device) for layer in snac_input_codes]
|
| 183 |
-
|
| 184 |
-
# --- Decode Audio (SNAC Inference) ---
|
| 185 |
-
print("Decoding audio with SNAC...")
|
| 186 |
-
with torch.no_grad():
|
| 187 |
-
audio_hat = self.snac_model.decode(snac_input_codes)
|
| 188 |
-
print(f"Decoded audio tensor shape: {audio_hat.shape}") # Should be [1, 1, num_samples]
|
| 189 |
-
|
| 190 |
-
# --- Postprocess Audio ---
|
| 191 |
-
# Move to CPU, remove batch/channel dims, convert to numpy
|
| 192 |
-
audio_waveform = audio_hat.detach().squeeze().cpu().numpy()
|
| 193 |
-
|
| 194 |
-
# --- Convert to WAV Bytes ---
|
| 195 |
-
buffer = io.BytesIO()
|
| 196 |
-
sf.write(buffer, audio_waveform, self.sampling_rate, format='WAV')
|
| 197 |
-
buffer.seek(0)
|
| 198 |
-
wav_bytes = buffer.read()
|
| 199 |
-
|
| 200 |
-
print(f"Generated {len(wav_bytes)} bytes of WAV audio.")
|
| 201 |
-
return wav_bytes
|
| 202 |
-
|
| 203 |
-
except Exception as e:
|
| 204 |
-
print(f"Error during inference call: {e}")
|
| 205 |
-
# Re-raise for endpoint framework
|
| 206 |
-
raise RuntimeError(f"Inference failed: {e}")
|
|
|
|
| 1 |
+
import base64
import io
import os
import wave

import numpy as np
import torch

from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
class EndpointHandler:
    """Inference Endpoints handler for the Hypa Orpheus text-to-speech model.

    Text is turned into audio in two stages: the Orpheus LLM generates a
    stream of audio-code tokens, and the SNAC vocoder decodes those codes
    into a 24 kHz waveform.
    """

    def __init__(self, path=""):
        """Load the Orpheus model, its tokenizer, and the SNAC vocoder.

        Args:
            path: Directory of the deployed model repository (supplied by the
                endpoint framework). Falls back to the public hub checkpoint
                when empty.
        """
        # BUG FIX: honor the endpoint-supplied model path instead of always
        # re-downloading the hub checkpoint.
        self.model_name = path if path else "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        )

        # Move model to GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout etc.

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load SNAC model for audio decoding
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model.to(self.device)
        self.snac_model.eval()  # inference only

        # Special tokens of the Orpheus vocabulary layout
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.padding_token = 128263
        self.start_audio_token = 128257  # Start of Audio token
        self.end_audio_token = 128258  # End of Audio token

        print(f"Model loaded on {self.device}")
|
| 35 |
+
|
| 36 |
+
def preprocess(self, data):
    """Preprocess input data before inference.

    Args:
        data: Request payload. ``data["inputs"]`` may be a dict with keys
            ``text``, ``voice``, ``temperature``, ``top_p``,
            ``max_new_tokens``, ``repetition_penalty`` — or, per the standard
            Inference Endpoints convention, a bare string of text.

    Returns:
        Dict with device-resident ``input_ids`` / ``attention_mask`` tensors
        and the parsed sampling parameters.
    """
    inputs = data.pop("inputs", data)

    # BUG FIX: the canonical Inference Endpoints payload is
    # {"inputs": "<text>"}; a bare string used to crash on .get().
    if isinstance(inputs, str):
        inputs = {"text": inputs}

    # Extract parameters from request
    text = inputs.get("text", "")
    voice = inputs.get("voice", "tara")
    temperature = float(inputs.get("temperature", 0.6))
    top_p = float(inputs.get("top_p", 0.95))
    max_new_tokens = int(inputs.get("max_new_tokens", 1200))
    repetition_penalty = float(inputs.get("repetition_penalty", 1.1))

    # Orpheus expects "<voice>: <text>" prompts
    prompt = f"{voice}: {text}"

    # Tokenize
    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

    # Wrap with start-of-human / end-of-text + end-of-human special tokens
    modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

    # No need for padding as we're processing a single sequence
    input_ids = modified_input_ids.to(self.device)
    attention_mask = torch.ones_like(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty
    }
|
| 71 |
+
|
| 72 |
+
def inference(self, inputs):
    """Run Orpheus generation on a preprocessed request.

    Args:
        inputs: Dict produced by ``preprocess`` — tensors plus sampling
            parameters.

    Returns:
        Tensor of generated token ids (prompt followed by audio codes).
    """
    # Sampling configuration, terminated at the End-of-Audio token.
    sampling_kwargs = {
        "max_new_tokens": inputs["max_new_tokens"],
        "do_sample": True,
        "temperature": inputs["temperature"],
        "top_p": inputs["top_p"],
        "repetition_penalty": inputs["repetition_penalty"],
        "num_return_sequences": 1,
        "eos_token_id": self.end_audio_token,
    }

    # No gradients needed for pure generation.
    with torch.no_grad():
        return self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **sampling_kwargs,
        )
|
| 99 |
+
|
| 100 |
+
def postprocess(self, generated_ids):
    """Convert generated token ids into a base64-encoded WAV response.

    Args:
        generated_ids: (batch, seq) tensor returned by ``inference``.

    Returns:
        Dict with ``audio_b64`` (mono 16-bit 24 kHz WAV, base64-encoded)
        and ``sample_rate``.

    Raises:
        ValueError: If no audio codes were generated.
    """
    # Keep only tokens after the *last* Start-of-Audio marker; everything
    # before it is the echoed text prompt.
    token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)

    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids

    # Remove End of Audio tokens
    processed_rows = []
    for row in cropped_tensor:
        processed_rows.append(row[row != self.end_audio_token])

    # SNAC consumes codes in groups of 7: trim each row to a multiple of 7
    # and shift token ids down into the raw audio-code range.
    code_lists = []
    for row in processed_rows:
        new_length = (row.size(0) // 7) * 7
        if new_length == 0:
            # BUG FIX: an empty code list crashes deep inside SNAC decode;
            # fail early with an actionable message instead.
            raise ValueError("No audio codes were generated for this request.")
        code_lists.append([t.item() - 128266 for t in row[:new_length]])

    # Decode each row of codes to a waveform (batch size is 1 in practice).
    audio_samples = [self.redistribute_codes(code_list) for code_list in code_lists]

    # First (and only) sample -> float waveform on CPU.
    audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

    # BUG FIX: clip to [-1, 1] before the int16 conversion; without this,
    # out-of-range samples wrap around and produce loud clicking artifacts.
    audio_int16 = (np.clip(audio_sample, -1.0, 1.0) * 32767).astype(np.int16)

    # Serialize as an in-memory mono 16-bit 24 kHz WAV file.
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, 'wb') as wav_file:
            wav_file.setnchannels(1)  # Mono
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(24000)  # 24kHz
            wav_file.writeframes(audio_int16.tobytes())
        wav_data = wav_io.getvalue()

    # Encode as base64 for JSON transmission.
    audio_b64 = base64.b64encode(wav_data).decode('utf-8')

    return {
        "audio_b64": audio_b64,
        "sample_rate": 24000
    }
|
| 162 |
+
|
| 163 |
+
def redistribute_codes(self, code_list):
    """Reorganize tokens for SNAC decoding.

    Every group of 7 flattened tokens encodes one frame across SNAC's three
    codebook layers (coarse / intermediate / fine); the k-th token in a group
    carries an offset of k * 4096 that must be subtracted.

    Args:
        code_list: Flat list of adjusted audio codes (length a multiple of 7).

    Returns:
        Audio tensor decoded by the SNAC vocoder.
    """
    coarse, mid, fine = [], [], []

    # Walk the stream one 7-token frame at a time.
    for base in range(0, (len(code_list) // 7) * 7, 7):
        coarse.append(code_list[base])
        mid.append(code_list[base + 1] - 4096)
        fine.append(code_list[base + 2] - (2 * 4096))
        fine.append(code_list[base + 3] - (3 * 4096))
        mid.append(code_list[base + 4] - (4 * 4096))
        fine.append(code_list[base + 5] - (5 * 4096))
        fine.append(code_list[base + 6] - (6 * 4096))

    codes = [
        torch.tensor(layer).unsqueeze(0).to(self.device)
        for layer in (coarse, mid, fine)
    ]

    # Decode audio
    return self.snac_model.decode(codes)
|
| 191 |
+
|
| 192 |
+
def __call__(self, data):
    """Main entry point for the handler.

    Chains the preprocess -> inference -> postprocess pipeline.

    Args:
        data: Request payload dict.

    Returns:
        Response dict with the base64-encoded WAV audio and its sample rate.
    """
    return self.postprocess(self.inference(self.preprocess(data)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|