from typing import Dict, List, Any

import base64
import io
import re
import wave

import numpy as np
import torch
from snac import SNAC
from unsloth import FastLanguageModel


class EndpointHandler:
    def __init__(self, path=""):
        # Load the fine-tuned Igbo TTS language model through Unsloth.
        # Note: the model is pulled from the Hub; the `path` argument supplied
        # by Inference Endpoints is unused here.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name="kosinebolisa/igbo-tts-model",
            device_map="auto",
            load_in_4bit=False
        )
        FastLanguageModel.for_inference(self.model)

        # SNAC codec that decodes the generated audio codes into a 24 kHz waveform.
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.to("cpu")
        self.snac_model.eval()

        # Igbo number words used when normalizing digits in the input text.
        self.number_words = {
            0: "oroghoro", 1: "otu", 2: "abụọ", 3: "atọ", 4: "anọ", 5: "ise",
            6: "isii", 7: "asaa", 8: "asatọ", 9: "itoolu",
            10: "iri", 11: "iri na otu", 12: "iri na abụọ", 13: "iri na atọ",
            14: "iri na anọ", 15: "iri na ise", 16: "iri na isii",
            17: "iri na asaa", 18: "iri na asatọ", 19: "iri na iteghete",
            20: "iri abụọ", 30: "iri atọ", 40: "iri anọ", 50: "iri ise",
            60: "iri isii", 70: "iri asaa", 80: "iri asatọ", 90: "iri itoolu",
            100: "otu narị", 1000: "otu puku"
        }

        # Special token IDs that frame the prompt and delimit the audio codes.
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
        self.pad_token_id = 128263
        self.eos_token_id = 128258
        self.code_start_token = 128257
        self.code_offset = 128266

    def number_to_words(self, number):
        """Convert a non-negative integer to Igbo words."""
        if number < 20:
            return self.number_words[number]
        elif number < 100:
            tens, unit = divmod(number, 10)
            return self.number_words[tens * 10] + (" na " + self.number_words[unit] if unit else "")
        elif number < 1000:
            hundreds, remainder = divmod(number, 100)
            base = (self.number_words[hundreds] + " narị") if hundreds > 1 else "otu narị"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000:
            thousands, remainder = divmod(number, 1000)
            base = self.number_to_words(thousands) + " puku" if thousands > 1 else "otu puku"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000_000:
            millions, remainder = divmod(number, 1_000_000)
            base = self.number_to_words(millions) + " nde"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000_000_000:
            billions, remainder = divmod(number, 1_000_000_000)
            base = self.number_to_words(billions) + " ijeri"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        else:
            return str(number)
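
    # Illustrative outputs, hand-traced from the lookup table and the
    # recursion above (not produced by running the model):
    #   number_to_words(13)  -> "iri na atọ"
    #   number_to_words(45)  -> "iri anọ na ise"
    #   number_to_words(245) -> "abụọ narị na iri anọ na ise"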

    def replace_numbers_with_words(self, text):
        """Replace standalone digit sequences in the text with Igbo words."""
        def replace(match):
            number = int(match.group())
            return self.number_to_words(number)

        return re.sub(r'\b\d+\b', replace, text)

    def preprocess_text(self, text, voice=None):
        """Normalize numbers and optionally prepend a voice/speaker prefix."""
        processed_text = self.replace_numbers_with_words(text)

        if voice:
            processed_text = f"{voice}: {processed_text}"

        return processed_text

    def prepare_input_ids(self, texts):
        """Tokenize, wrap with special tokens, and left-pad a batch of prompts."""
        if isinstance(texts, str):
            texts = [texts]

        all_input_ids = []
        for text in texts:
            input_ids = self.tokenizer(text, return_tensors="pt").input_ids
            all_input_ids.append(input_ids)

        # Wrap each prompt: start token, text tokens, end-of-text markers.
        all_modified_input_ids = []
        for input_ids in all_input_ids:
            modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
            all_modified_input_ids.append(modified_input_ids)

        # Left-pad every sequence to the length of the longest one.
        max_length = max(ids.shape[1] for ids in all_modified_input_ids)
        all_padded_tensors = []
        all_attention_masks = []

        for modified_input_ids in all_modified_input_ids:
            padding = max_length - modified_input_ids.shape[1]
            padded_tensor = torch.cat([
                torch.full((1, padding), self.pad_token_id, dtype=torch.int64),
                modified_input_ids
            ], dim=1)
            attention_mask = torch.cat([
                torch.zeros((1, padding), dtype=torch.int64),
                torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)
            ], dim=1)
            all_padded_tensors.append(padded_tensor)
            all_attention_masks.append(attention_mask)

        input_ids = torch.cat(all_padded_tensors, dim=0).to(self.model.device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(self.model.device)

        return input_ids, attention_mask
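
    # Note on the padding direction above: sequences are left-padded (pad token
    # 128263, mask 0) so that every prompt ends at the same position. A
    # decoder-only model continues generation from the last position, so
    # right-padding would leave pad tokens between the prompt and the
    # generated audio codes.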

    def generate_codes(self, input_ids, attention_mask, **generation_params):
        """Generate audio-code tokens with the language model."""
        default_params = {
            'max_new_tokens': 1200,
            'do_sample': True,
            'temperature': 0.6,
            'top_p': 0.95,
            'repetition_penalty': 1.1,
            'num_return_sequences': 1,
            'eos_token_id': self.eos_token_id,
            'use_cache': True
        }
        # Caller-supplied parameters override the defaults.
        default_params.update(generation_params)

        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **default_params
        )

        return generated_ids

    def process_generated_codes(self, generated_ids):
        """Extract raw audio codes from the generated token IDs."""
        # Crop everything up to and including the last code-start token.
        token_indices = (generated_ids == self.code_start_token).nonzero(as_tuple=True)

        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
        else:
            cropped_tensor = generated_ids

        # Drop EOS tokens from each row.
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.eos_token_id]
            processed_rows.append(masked_row)

        # Trim each row to a multiple of 7 and remove the vocabulary offset.
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t.item() - self.code_offset for t in trimmed_row]
            code_lists.append(trimmed_row)

        return code_lists
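
    # The generated stream, as implied by the token constants in __init__, is
    #   ... prompt ... 128257 <code tokens> 128258
    # so everything after the last 128257 is kept, EOS (128258) tokens are
    # dropped, and the remainder is trimmed to a multiple of 7 because the
    # SNAC decoder consumes codes in 7-token frames (see redistribute_codes).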

    def redistribute_codes(self, code_list):
        """Redistribute a flat code list into the three SNAC layers."""
        layer_1 = []
        layer_2 = []
        layer_3 = []

        # Each 7-code frame maps onto the layers as 1 + 2 + 4 codes:
        # slot 0 -> layer 1; slots 1, 4 -> layer 2; slots 2, 3, 5, 6 -> layer 3.
        # Slot k carries an offset of k * 4096 that must be subtracted.
        for i in range((len(code_list) + 1) // 7):
            if 7 * i < len(code_list):
                layer_1.append(code_list[7 * i])
            if 7 * i + 1 < len(code_list):
                layer_2.append(code_list[7 * i + 1] - 4096)
            if 7 * i + 2 < len(code_list):
                layer_3.append(code_list[7 * i + 2] - (2 * 4096))
            if 7 * i + 3 < len(code_list):
                layer_3.append(code_list[7 * i + 3] - (3 * 4096))
            if 7 * i + 4 < len(code_list):
                layer_2.append(code_list[7 * i + 4] - (4 * 4096))
            if 7 * i + 5 < len(code_list):
                layer_3.append(code_list[7 * i + 5] - (5 * 4096))
            if 7 * i + 6 < len(code_list):
                layer_3.append(code_list[7 * i + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0),
            torch.tensor(layer_2).unsqueeze(0),
            torch.tensor(layer_3).unsqueeze(0)
        ]

        # Decode on CPU (the SNAC model was pinned to CPU in __init__).
        self.snac_model.to("cpu")
        audio_hat = self.snac_model.decode(codes)

        return audio_hat

    def audio_to_wav_bytes(self, audio_tensor, sample_rate=24000):
        """Convert an audio tensor to 16-bit mono WAV bytes."""
        audio_np = audio_tensor.detach().squeeze().to("cpu").numpy()

        # Clip to [-1, 1] and scale to 16-bit PCM.
        audio_np = np.clip(audio_np, -1.0, 1.0)
        audio_int16 = (audio_np * 32767).astype(np.int16)

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_int16.tobytes())

        return wav_buffer.getvalue()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference entry point for TTS.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs: Text string, or list of text strings, to synthesize
                - parameters (optional): Voice selection and generation settings

        Returns:
            List[Dict[str, Any]]: One result per input text, each carrying
            base64-encoded WAV audio or an error message.
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            if not inputs:
                return [{"error": "No input text provided"}]

            # "voice" selects a speaker prefix; everything else is forwarded
            # to model.generate().
            voice = parameters.get("voice", None)
            generation_params = {k: v for k, v in parameters.items() if k != "voice"}

            if isinstance(inputs, str):
                texts = [inputs]
            else:
                texts = inputs

            processed_texts = [self.preprocess_text(text, voice) for text in texts]

            input_ids, attention_mask = self.prepare_input_ids(processed_texts)

            generated_ids = self.generate_codes(input_ids, attention_mask, **generation_params)

            code_lists = self.process_generated_codes(generated_ids)

            results = []
            for i, code_list in enumerate(code_lists):
                try:
                    audio_tensor = self.redistribute_codes(code_list)
                    wav_bytes = self.audio_to_wav_bytes(audio_tensor)
                    audio_b64 = base64.b64encode(wav_bytes).decode('utf-8')

                    results.append({
                        "text": texts[i] if i < len(texts) else processed_texts[i],
                        "audio": audio_b64,
                        "content_type": "audio/wav",
                        "sample_rate": 24000
                    })

                except Exception as e:
                    results.append({
                        "text": texts[i] if i < len(texts) else processed_texts[i],
                        "error": f"Audio generation failed: {str(e)}"
                    })

            return results

        except Exception as e:
            return [{"error": f"TTS processing failed: {str(e)}"}]