"""Custom inference endpoint handler for the Igbo text-to-speech model."""

import base64
import io
import logging
import re
import wave
from typing import Any, Dict, List

import numpy as np
import torch
from snac import SNAC
from unsloth import FastLanguageModel

logger = logging.getLogger(__name__)


class EndpointHandler():
    """Turn Igbo text into base64-encoded WAV audio.

    The pipeline is: spell out digits as Igbo words, tokenize and left-pad the
    text, let the language model generate audio-code tokens, regroup those
    codes into three SNAC layers, decode them to a waveform and serialise the
    waveform as 16-bit mono WAV.
    """

    # Number of consecutive code tokens that form one audio frame.
    CODES_PER_FRAME = 7
    # The code at position k of a frame is stored with an offset of k * 4096.
    SLOT_SPAN = 4096
    # Which SNAC layer (0, 1 or 2) each of the 7 frame positions feeds.
    LAYER_FOR_POSITION = (0, 1, 2, 2, 1, 2, 2)
    # Sample rate reported in responses and written into the WAV header.
    SAMPLE_RATE = 24000
    # Scale words used by number_to_words, largest first.
    _SCALES = (
        (1_000_000_000, "ijeri"),
        (1_000_000, "nde"),
        (1000, "puku"),
        (100, "narị"),
    )
    # Compiled once: whole runs of digits delimited by word boundaries.
    _NUMBER_PATTERN = re.compile(r'\b\d+\b')

    def __init__(self, path=""):
        """Load the language model, its tokenizer and the SNAC decoder.

        Args:
            path: Accepted for interface compatibility but not used; the
                model is always loaded from the hub id hard-coded below.
        """
        # Load the model from Hugging Face using unsloth
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name="kosinebolisa/igbo-tts-model",
            device_map="auto",
            load_in_4bit=False
        )

        # Enable inference mode for faster processing
        FastLanguageModel.for_inference(self.model)

        # SNAC turns the generated codes back into a waveform; it is kept on
        # the CPU and in eval mode.
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.to("cpu")
        self.snac_model.eval()

        # Igbo number word dictionary. Spellings are kept consistent across
        # units, teens and tens (8 -> "asatọ", 9 -> "itoolu") so that the same
        # digit is always read the same way.
        self.number_words = {
            0: "oroghoro", 1: "otu", 2: "abụọ", 3: "atọ", 4: "anọ",
            5: "ise", 6: "isii", 7: "asaa", 8: "asatọ", 9: "itoolu",
            10: "iri", 11: "iri na otu", 12: "iri na abụọ", 13: "iri na atọ",
            14: "iri na anọ", 15: "iri na ise", 16: "iri na isii",
            17: "iri na asaa", 18: "iri na asatọ", 19: "iri na itoolu",
            20: "iri abụọ", 30: "iri atọ", 40: "iri anọ", 50: "iri ise",
            60: "iri isii", 70: "iri asaa", 80: "iri asatọ",
            90: "iri itoolu", 100: "otu narị", 1000: "otu puku"
        }

        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.pad_token_id = 128263
        self.eos_token_id = 128258
        self.code_start_token = 128257
        self.code_offset = 128266

    def number_to_words(self, number):
        """Convert a non-negative integer below 10**12 to Igbo words.

        Args:
            number: Integer to spell out.

        Returns:
            The Igbo wording, or ``str(number)`` when the value is negative
            or at least 10**12 (no wording is defined for those).
        """
        if number < 0 or number >= 1_000_000_000_000:
            return str(number)
        if number < 20:
            return self.number_words[number]
        if number < 100:
            tens, unit = divmod(number, 10)
            words = self.number_words[tens * 10]
            if unit:
                words += " na " + self.number_words[unit]
            return words

        # Hundreds, thousands, millions and billions all follow the same
        # "<count> <scale word> [na <remainder>]" shape.
        for scale, scale_word in self._SCALES:
            if number >= scale:
                count, remainder = divmod(number, scale)
                words = f"{self.number_to_words(count)} {scale_word}"
                if remainder:
                    words += " na " + self.number_to_words(remainder)
                return words
        return str(number)  # Unreachable for integers; kept as a safe fallback.

    def replace_numbers_with_words(self, text):
        """Replace every standalone run of digits in ``text`` with Igbo words."""
        def replace(match):
            return self.number_to_words(int(match.group()))

        return self._NUMBER_PATTERN.sub(replace, text)

    def preprocess_text(self, text, voice=None):
        """Normalise numbers and optionally prefix the text with a voice name.

        Args:
            text: Raw input text.
            voice: Optional voice label; when truthy the result becomes
                ``"<voice>: <text>"``.

        Returns:
            The processed text.
        """
        processed_text = self.replace_numbers_with_words(text)
        if voice:
            processed_text = f"{voice}: {processed_text}"
        return processed_text

    def prepare_input_ids(self, texts):
        """Tokenize, wrap with special tokens and left-pad a batch of texts.

        Args:
            texts: A single string or a list of strings.

        Returns:
            A ``(input_ids, attention_mask)`` pair of int64 tensors with one
            row per text, moved to the language model's device.

        Raises:
            ValueError: If ``texts`` is empty.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            raise ValueError("No input text provided")

        # Wrap each prompt as: <start of human> text <end of text> <end of human>
        wrapped = [
            torch.cat(
                [
                    self.start_token,
                    self.tokenizer(text, return_tensors="pt").input_ids,
                    self.end_tokens,
                ],
                dim=1,
            )
            for text in texts
        ]

        max_length = max(ids.shape[1] for ids in wrapped)
        padded_rows = []
        mask_rows = []
        for ids in wrapped:
            length = ids.shape[1]
            padding = max_length - length
            # Pad on the left so every prompt ends at the same column and
            # generation continues directly after the prompt.
            padded_rows.append(torch.cat([
                torch.full((1, padding), self.pad_token_id, dtype=torch.int64),
                ids
            ], dim=1))
            mask_rows.append(torch.cat([
                torch.zeros((1, padding), dtype=torch.int64),
                torch.ones((1, length), dtype=torch.int64)
            ], dim=1))

        input_ids = torch.cat(padded_rows, dim=0).to(self.model.device)
        attention_mask = torch.cat(mask_rows, dim=0).to(self.model.device)
        return input_ids, attention_mask

    def generate_codes(self, input_ids, attention_mask, **generation_params):
        """Generate audio-code tokens with the language model.

        Args:
            input_ids: Prompt token ids, as returned by ``prepare_input_ids``.
            attention_mask: Matching attention mask.
            **generation_params: Overrides for the default sampling settings;
                forwarded to ``self.model.generate``.

        Returns:
            Whatever ``self.model.generate`` returns for these arguments.
        """
        params = {
            'max_new_tokens': 1200,
            'do_sample': True,
            'temperature': 0.6,
            'top_p': 0.95,
            'repetition_penalty': 1.1,
            'num_return_sequences': 1,
            'eos_token_id': self.eos_token_id,
            # Rows that finish early are filled with the same pad id that is
            # used for the prompt padding.
            'pad_token_id': self.pad_token_id,
            'use_cache': True
        }
        params.update(generation_params)

        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **params
        )

    def process_generated_codes(self, generated_ids):
        """Extract the audio codes of every generated row.

        Each row is handled independently: everything up to and including the
        row's last code-start token is dropped, the row is cut at its first
        EOS token (anything after it is padding), the remainder is trimmed to
        a whole number of frames and ``code_offset`` is subtracted.

        Args:
            generated_ids: 2-D tensor of token ids, one row per sequence.

        Returns:
            A list with one list of integer codes per row.
        """
        code_lists = []
        for row in generated_ids:
            start_positions = (row == self.code_start_token).nonzero(as_tuple=True)[0]
            if start_positions.numel() > 0:
                row = row[int(start_positions[-1]) + 1:]

            eos_positions = (row == self.eos_token_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                row = row[:int(eos_positions[0])]

            usable = (row.size(0) // self.CODES_PER_FRAME) * self.CODES_PER_FRAME
            code_lists.append(
                [token - self.code_offset for token in row[:usable].tolist()]
            )
        return code_lists

    def redistribute_codes(self, code_list):
        """Regroup a flat code list into three layers and decode it with SNAC.

        Only complete frames of 7 codes are used; trailing leftovers are
        ignored so the three layers always stay aligned.

        Args:
            code_list: Flat list of integer codes (offsets still applied per
                frame position).

        Returns:
            The tensor returned by ``self.snac_model.decode``.

        Raises:
            ValueError: If there is no complete frame, or a code falls outside
                the range expected for its frame position.
        """
        frame_count = len(code_list) // self.CODES_PER_FRAME
        if frame_count == 0:
            raise ValueError("No complete audio frames to decode")

        layers = ([], [], [])
        for frame_index in range(frame_count):
            first = frame_index * self.CODES_PER_FRAME
            frame = code_list[first:first + self.CODES_PER_FRAME]
            for position, raw_code in enumerate(frame):
                code = raw_code - position * self.SLOT_SPAN
                if not 0 <= code < self.SLOT_SPAN:
                    raise ValueError(
                        f"Audio code {raw_code} is out of range for frame position {position}"
                    )
                layers[self.LAYER_FOR_POSITION[position]].append(code)

        codes = [
            torch.tensor(layer, dtype=torch.long).unsqueeze(0) for layer in layers
        ]

        # Move SNAC model to CPU before decoding (no-op when already there).
        self.snac_model.to("cpu")
        # Decoding is inference only; skip autograd bookkeeping.
        with torch.inference_mode():
            audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def audio_to_wav_bytes(self, audio_tensor, sample_rate=24000):
        """Serialise an audio tensor as 16-bit mono WAV bytes.

        Args:
            audio_tensor: Waveform tensor; samples are clipped to [-1, 1].
            sample_rate: Frame rate written into the WAV header.

        Returns:
            The complete WAV file as ``bytes``.
        """
        samples = audio_tensor.detach().squeeze().to("cpu").numpy()

        # Normalize audio to int16 range
        samples = np.clip(samples, -1.0, 1.0)
        pcm = (samples * 32767).astype(np.int16)

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # Mono
            wav_file.setsampwidth(2)  # 2 bytes per sample (int16)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm.tobytes())

        return wav_buffer.getvalue()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference function for TTS.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs: Text string or list of text strings to synthesize
                - parameters (optional): Additional parameters like voice,
                  generation settings

        Returns:
            List[Dict[str, Any]]: One entry per generated sequence. A
            successful entry holds ``text``, ``audio`` (base64 WAV),
            ``content_type`` and ``sample_rate``; a failed one holds ``text``
            and ``error``. Request-level failures yield a single
            ``{"error": ...}`` entry. Exceptions are logged, never raised.
        """
        try:
            inputs = data.get("inputs", "")
            # An explicit null for "parameters" is treated like an omitted one.
            parameters = data.get("parameters") or {}

            if not inputs:
                return [{"error": "No input text provided"}]

            texts = [inputs] if isinstance(inputs, str) else list(inputs)
            if not all(isinstance(text, str) for text in texts):
                return [{"error": "inputs must be a string or a list of strings"}]

            voice = parameters.get("voice", None)
            generation_params = {k: v for k, v in parameters.items() if k != "voice"}

            processed_texts = [self.preprocess_text(text, voice) for text in texts]
            input_ids, attention_mask = self.prepare_input_ids(processed_texts)
            generated_ids = self.generate_codes(input_ids, attention_mask, **generation_params)
            code_lists = self.process_generated_codes(generated_ids)

            # With num_return_sequences > 1 each text yields several
            # consecutive rows; map every row back to its source text.
            rows_per_text = max(1, len(code_lists) // len(texts))

            results = []
            for i, code_list in enumerate(code_lists):
                source_text = texts[min(i // rows_per_text, len(texts) - 1)]
                try:
                    audio_tensor = self.redistribute_codes(code_list)
                    wav_bytes = self.audio_to_wav_bytes(audio_tensor, self.SAMPLE_RATE)
                    audio_b64 = base64.b64encode(wav_bytes).decode('utf-8')
                    results.append({
                        "text": source_text,
                        "audio": audio_b64,
                        "content_type": "audio/wav",
                        "sample_rate": self.SAMPLE_RATE
                    })
                except Exception as e:
                    # One bad sequence must not discard the rest of the batch.
                    logger.exception("Audio generation failed for sequence %d", i)
                    results.append({
                        "text": source_text,
                        "error": f"Audio generation failed: {str(e)}"
                    })

            return results

        except Exception as e:
            # Top-level boundary: report the failure in the response body.
            logger.exception("TTS processing failed")
            return [{"error": f"TTS processing failed: {str(e)}"}]