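"""Hugging Face Inference Endpoints handler for an Igbo TTS model.

Incoming text is normalized (digits become Igbo number words), wrapped in
special tokens, fed to a fine-tuned language model that emits SNAC audio
codes, and decoded to 24 kHz mono WAV returned as base64.
"""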
from typing import Dict, List, Any
import torch
import re
import numpy as np
import io
import base64
from snac import SNAC
import wave
from unsloth import FastLanguageModel

class EndpointHandler:
    def __init__(self, path=""):
        # `path` is the local checkpoint directory Inference Endpoints passes
        # in; this handler loads the model from the Hub by name instead.

        # Load the model from Hugging Face using unsloth
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name="kosinebolisa/igbo-tts-model",  
            device_map="auto",
            load_in_4bit=False
        )
        
        # Enable inference mode for faster processing
        FastLanguageModel.for_inference(self.model)
        
        # Load the SNAC neural codec that decodes generated audio codes into waveforms
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.to("cpu")
        self.snac_model.eval()
        
        # Igbo number word dictionary
        self.number_words = {
            0: "oroghoro", 1: "otu", 2: "abụọ", 3: "atọ", 4: "anọ", 5: "ise", 6: "isii", 7: "asaa", 8: "asato", 9: "itoolu",
            10: "iri", 11: "iri na otu", 12: "iri na abụọ", 13: "iri na atọ", 14: "iri na anọ", 15: "iri na ise",
            16: "iri na isii", 17: "iri na asaa", 18: "iri na asato", 19: "iri na iteghete",
            20: "iri abụọ", 30: "otuz", 40: "iri anọ", 50: "iri ise", 60: "iri isii", 70: "iri asaa",
            80: "iri asatọ", 90: "iri itoolu", 100: "otu narị", 1000: "otu puku"
        }
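        # Composed values are produced by number_to_words below,
        # e.g. 21 -> "iri abụọ na otu", 105 -> "otu narị na ise".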
        
        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.pad_token_id = 128263
        self.eos_token_id = 128258
        self.code_start_token = 128257
        self.code_offset = 128266
        
    def number_to_words(self, number):
        """Convert numbers to Igbo words."""
        if number < 20:
            return self.number_words[number]
        elif number < 100:
            tens, unit = divmod(number, 10)
            return self.number_words[tens * 10] + (" na " + self.number_words[unit] if unit else "")
        elif number < 1000:
            hundreds, remainder = divmod(number, 100)
            base = (self.number_words[hundreds] + " narị") if hundreds > 1 else "otu narị"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000:
            thousands, remainder = divmod(number, 1000)
            base = (self.number_to_words(thousands) + " puku") if thousands > 1 else "otu puku"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000_000:
            millions, remainder = divmod(number, 1_000_000)
            base = self.number_to_words(millions) + " nde"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000_000_000:
            billions, remainder = divmod(number, 1_000_000_000)
            base = self.number_to_words(billions) + " ijeri"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        else:
            return str(number)
    
    def replace_numbers_with_words(self, text):
        """Replace numbers in text with Igbo words."""
        def replace(match):
            number = int(match.group())
            return self.number_to_words(number)
        
        return re.sub(r'\b\d+\b', replace, text)
    
    def preprocess_text(self, text, voice=None):
        """Preprocess input text."""
        # Normalize numbers
        processed_text = self.replace_numbers_with_words(text)
        
        # Add voice prefix if specified
        if voice:
            processed_text = f"{voice}: {processed_text}"
            
        return processed_text
    
    def prepare_input_ids(self, texts):
        """Prepare input IDs for batch processing."""
        if isinstance(texts, str):
            texts = [texts]
            
        all_input_ids = []
        for text in texts:
            input_ids = self.tokenizer(text, return_tensors="pt").input_ids
            all_input_ids.append(input_ids)
        
        # Wrap each sequence in the start-of-human / end-of-text+human markers
        all_modified_input_ids = []
        for input_ids in all_input_ids:
            modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
            all_modified_input_ids.append(modified_input_ids)
        
        # Left-pad to the longest sequence so prompts end flush against the
        # generation position (as decoder-only batched generation expects)
        max_length = max([ids.shape[1] for ids in all_modified_input_ids])
        all_padded_tensors = []
        all_attention_masks = []
        
        for modified_input_ids in all_modified_input_ids:
            padding = max_length - modified_input_ids.shape[1]
            padded_tensor = torch.cat([
                torch.full((1, padding), self.pad_token_id, dtype=torch.int64), 
                modified_input_ids
            ], dim=1)
            attention_mask = torch.cat([
                torch.zeros((1, padding), dtype=torch.int64), 
                torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)
            ], dim=1)
            all_padded_tensors.append(padded_tensor)
            all_attention_masks.append(attention_mask)
        
        input_ids = torch.cat(all_padded_tensors, dim=0).to(self.model.device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(self.model.device)
        
        return input_ids, attention_mask
    
    def generate_codes(self, input_ids, attention_mask, **generation_params):
        """Generate audio codes using the language model."""
        default_params = {
            'max_new_tokens': 1200,
            'do_sample': True,
            'temperature': 0.6,
            'top_p': 0.95,
            'repetition_penalty': 1.1,
            'num_return_sequences': 1,
            'eos_token_id': self.eos_token_id,
            'use_cache': True
        }
        default_params.update(generation_params)
        
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **default_params
        )
        
        return generated_ids
    
    def process_generated_codes(self, generated_ids):
        """Process generated token IDs to extract audio codes."""
        # Find the last occurrence of code start token
        token_indices = (generated_ids == self.code_start_token).nonzero(as_tuple=True)
        
        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
        else:
            cropped_tensor = generated_ids
        
        # Remove EOS tokens
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.eos_token_id]
            processed_rows.append(masked_row)
        
        # Convert to code lists
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t.item() - self.code_offset for t in trimmed_row]
            code_lists.append(trimmed_row)
        
        return code_lists
    
    def redistribute_codes(self, code_list):
        """Redistribute codes into layers for SNAC decoding."""
        layer_1 = []
        layer_2 = []
        layer_3 = []
        
        for i in range((len(code_list)+1)//7):
            if 7*i < len(code_list):
                layer_1.append(code_list[7*i])
            if 7*i+1 < len(code_list):
                layer_2.append(code_list[7*i+1]-4096)
            if 7*i+2 < len(code_list):
                layer_3.append(code_list[7*i+2]-(2*4096))
            if 7*i+3 < len(code_list):
                layer_3.append(code_list[7*i+3]-(3*4096))
            if 7*i+4 < len(code_list):
                layer_2.append(code_list[7*i+4]-(4*4096))
            if 7*i+5 < len(code_list):
                layer_3.append(code_list[7*i+5]-(5*4096))
            if 7*i+6 < len(code_list):
                layer_3.append(code_list[7*i+6]-(6*4096))
        
        codes = [
            torch.tensor(layer_1).unsqueeze(0),
            torch.tensor(layer_2).unsqueeze(0),
            torch.tensor(layer_3).unsqueeze(0)
        ]
        
        # Decode on CPU (the SNAC model already lives there); no gradients needed
        self.snac_model.to("cpu")
        with torch.inference_mode():
            audio_hat = self.snac_model.decode(codes)
        
        return audio_hat
    
    def audio_to_wav_bytes(self, audio_tensor, sample_rate=24000):
        """Convert audio tensor to WAV bytes."""
        audio_np = audio_tensor.detach().squeeze().to("cpu").numpy()
        
        # Normalize audio to int16 range
        audio_np = np.clip(audio_np, -1.0, 1.0)
        audio_int16 = (audio_np * 32767).astype(np.int16)
        
        # Create WAV file in memory
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # Mono
            wav_file.setsampwidth(2)  # 2 bytes per sample (int16)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_int16.tobytes())
        
        wav_buffer.seek(0)
        return wav_buffer.getvalue()
    
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference function for TTS.
        
        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs: Text string or list of text strings to synthesize
                - parameters (optional): Additional parameters like voice, generation settings
                
        Returns:
            List[Dict[str, Any]]: List containing audio data in base64 format
        """
        try:
            # Extract inputs and parameters
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})
            
            if not inputs:
                return [{"error": "No input text provided"}]
            
            # Extract parameters
            voice = parameters.get("voice", None)
            generation_params = {k: v for k, v in parameters.items() if k != "voice"}
            
            # Preprocess text
            if isinstance(inputs, str):
                texts = [inputs]
            else:
                texts = inputs
                
            processed_texts = [self.preprocess_text(text, voice) for text in texts]
            
            # Prepare input IDs
            input_ids, attention_mask = self.prepare_input_ids(processed_texts)
            
            # Generate codes
            generated_ids = self.generate_codes(input_ids, attention_mask, **generation_params)
            
            # Process codes
            code_lists = self.process_generated_codes(generated_ids)
            
            # Generate audio for each input
            results = []
            for i, code_list in enumerate(code_lists):
                try:
                    # Generate audio
                    audio_tensor = self.redistribute_codes(code_list)
                    
                    # Convert to WAV bytes
                    wav_bytes = self.audio_to_wav_bytes(audio_tensor)
                    
                    # Encode to base64
                    audio_b64 = base64.b64encode(wav_bytes).decode('utf-8')
                    
                    results.append({
                        # num_return_sequences > 1 can yield more code lists
                        # than inputs; clamp the index instead of raising
                        "text": texts[min(i, len(texts) - 1)],
                        "audio": audio_b64,
                        "content_type": "audio/wav",
                        "sample_rate": 24000
                    })
                    
                except Exception as e:
                    results.append({
                        "text": texts[min(i, len(texts) - 1)],
                        "error": f"Audio generation failed: {str(e)}"
                    })
            
            return results
            
        except Exception as e:
            return [{"error": f"TTS processing failed: {str(e)}"}]
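
# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, not part of the endpoint contract.
# Assumes the checkpoints above are reachable and enough memory is available;
# the sample sentence and output filenames are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": "Nnọọ, kedu ka ị mere?",
        "parameters": {"temperature": 0.6, "max_new_tokens": 1200},
    }
    for idx, item in enumerate(handler(payload)):
        if "audio" in item:
            # Decode the base64 WAV payload and write it to disk
            with open(f"output_{idx}.wav", "wb") as f:
                f.write(base64.b64decode(item["audio"]))
            print(f"wrote output_{idx}.wav ({item['sample_rate']} Hz)")
        else:
            print("error:", item.get("error"))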