Nebolisa Kosiso committed on
Commit
e67f037
·
1 Parent(s): 6cffb78

add custom handler

Browse files
__pycache__/handler.cpython-313.pyc ADDED
Binary file (9.22 kB). View file
 
handler.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import re
4
+ import numpy as np
5
+ import io
6
+ import base64
7
+ from snac import SNAC
8
+ import wave
9
+ from unsloth import FastLanguageModel
10
+
11
class EndpointHandler():
    """Hugging Face Inference Endpoint handler for Igbo text-to-speech (TTS).

    Pipeline per request:
      1. Normalize the input text (digit runs -> Igbo number words, optional
         "voice:" speaker prefix).
      2. Tokenize and frame the prompt with the model's special tokens.
      3. Generate audio-code tokens with the fine-tuned language model.
      4. Decode the codes to a waveform with the SNAC codec and return it as
         base64-encoded 24 kHz mono WAV.
    """

    def __init__(self, path=""):
        # `path` is the local model directory HF endpoints pass in; this
        # handler loads from the Hub instead, so the argument is unused.

        # Load the fine-tuned TTS language model via unsloth.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name="kosinebolisa/igbo-tts-model",
            device_map="auto",
            load_in_4bit=False
        )

        # Switch unsloth into its optimized inference mode.
        FastLanguageModel.for_inference(self.model)

        # SNAC neural audio codec: turns generated code tokens into audio.
        # Kept on CPU; decoding happens there (see redistribute_codes).
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model = self.snac_model.to("cpu")
        self.snac_model.eval()

        # Igbo number-word dictionary used for digit normalization.
        # NOTE(review): 30 -> "otuz" breaks the "iri + unit" pattern used by
        # 20/40/50/... (expected something like "iri atọ") — confirm with a
        # speaker before changing the data.
        self.number_words = {
            0: "oroghoro", 1: "otu", 2: "abụọ", 3: "atọ", 4: "anọ", 5: "ise", 6: "isii", 7: "asaa", 8: "asato", 9: "itoolu",
            10: "iri", 11: "iri na otu", 12: "iri na abụọ", 13: "iri na atọ", 14: "iri na anọ", 15: "iri na ise",
            16: "iri na isii", 17: "iri na asaa", 18: "iri na asato", 19: "iri na iteghete",
            20: "iri abụọ", 30: "otuz", 40: "iri anọ", 50: "iri ise", 60: "iri isii", 70: "iri asaa",
            80: "iri asatọ", 90: "iri itoolu", 100: "otu narị", 1000: "otu puku"
        }

        # Special token ids for the model's prompt/answer framing.
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)   # start of human turn
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end of text, end of human turn
        self.pad_token_id = 128263
        self.eos_token_id = 128258
        self.code_start_token = 128257  # marks where audio codes begin in the output
        self.code_offset = 128266       # token id corresponding to audio code 0

    def number_to_words(self, number):
        """Recursively convert a non-negative integer to Igbo words.

        Falls back to the plain digit string for values >= 10**12.
        """
        if number < 20:
            return self.number_words[number]
        elif number < 100:
            tens, unit = divmod(number, 10)
            return self.number_words[tens * 10] + (" na " + self.number_words[unit] if unit else "")
        elif number < 1000:
            hundreds, remainder = divmod(number, 100)
            base = (self.number_words[hundreds] + " narị") if hundreds > 1 else "otu narị"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000:
            thousands, remainder = divmod(number, 1000)
            base = self.number_to_words(thousands) + " puku" if thousands > 1 else "otu puku"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000_000:
            millions, remainder = divmod(number, 1_000_000)
            base = self.number_to_words(millions) + " nde"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        elif number < 1_000_000_000_000:
            billions, remainder = divmod(number, 1_000_000_000)
            base = self.number_to_words(billions) + " ijeri"
            return base + (" na " + self.number_to_words(remainder) if remainder else "")
        else:
            return str(number)

    def replace_numbers_with_words(self, text):
        """Replace every standalone digit run in ``text`` with Igbo words."""
        def replace(match):
            return self.number_to_words(int(match.group()))

        return re.sub(r'\b\d+\b', replace, text)

    def preprocess_text(self, text, voice=None):
        """Normalize digits and optionally prepend a ``voice: `` speaker prefix."""
        processed_text = self.replace_numbers_with_words(text)
        if voice:
            processed_text = f"{voice}: {processed_text}"
        return processed_text

    def prepare_input_ids(self, texts):
        """Tokenize, frame with special tokens, and left-pad a batch of texts.

        Returns ``(input_ids, attention_mask)`` on the model's device.
        Padding is on the LEFT, as required for decoder-only generation.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize each text and wrap it with the start/end framing tokens.
        all_modified_input_ids = []
        for text in texts:
            input_ids = self.tokenizer(text, return_tensors="pt").input_ids
            all_modified_input_ids.append(
                torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
            )

        # Left-pad every sequence to the batch maximum.
        max_length = max(ids.shape[1] for ids in all_modified_input_ids)
        all_padded_tensors = []
        all_attention_masks = []
        for modified_input_ids in all_modified_input_ids:
            padding = max_length - modified_input_ids.shape[1]
            all_padded_tensors.append(torch.cat([
                torch.full((1, padding), self.pad_token_id, dtype=torch.int64),
                modified_input_ids
            ], dim=1))
            all_attention_masks.append(torch.cat([
                torch.zeros((1, padding), dtype=torch.int64),
                torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)
            ], dim=1))

        input_ids = torch.cat(all_padded_tensors, dim=0).to(self.model.device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(self.model.device)

        return input_ids, attention_mask

    def generate_codes(self, input_ids, attention_mask, **generation_params):
        """Sample audio-code tokens from the language model.

        Caller-supplied ``generation_params`` override the defaults below.
        """
        default_params = {
            'max_new_tokens': 1200,
            'do_sample': True,
            'temperature': 0.6,
            'top_p': 0.95,
            'repetition_penalty': 1.1,
            'num_return_sequences': 1,
            'eos_token_id': self.eos_token_id,
            'use_cache': True
        }
        default_params.update(generation_params)

        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **default_params
        )

    def process_generated_codes(self, generated_ids):
        """Extract one audio-code list per batch row from generated token ids.

        For each row independently: crop everything up to and including the
        row's last ``code_start_token``, drop EOS tokens, trim to a whole
        number of 7-token SNAC frames, and shift by ``code_offset`` so codes
        start at 0.
        """
        code_lists = []
        for row in generated_ids:
            # Crop at this row's own last code-start marker. (Fix: the
            # previous version used one batch-global index, which misaligned
            # the codes of every other row when batching >1 input.)
            starts = (row == self.code_start_token).nonzero(as_tuple=True)[0]
            if starts.numel() > 0:
                row = row[starts[-1].item() + 1:]
            # Remove EOS tokens.
            row = row[row != self.eos_token_id]
            # SNAC consumes whole 7-token frames only.
            usable = (row.size(0) // 7) * 7
            code_lists.append([t.item() - self.code_offset for t in row[:usable]])
        return code_lists

    def redistribute_codes(self, code_list):
        """Split a flat 7-codes-per-frame list into SNAC's 3 layers and decode.

        Frame layout (by position): [L1, L2, L3, L3, L2, L3, L3]; each slot
        carries a k*4096 offset that is removed here. Returns the decoded
        waveform tensor from SNAC.
        """
        layer_1 = []
        layer_2 = []
        layer_3 = []

        for i in range((len(code_list)+1)//7):
            if 7*i < len(code_list):
                layer_1.append(code_list[7*i])
            if 7*i+1 < len(code_list):
                layer_2.append(code_list[7*i+1]-4096)
            if 7*i+2 < len(code_list):
                layer_3.append(code_list[7*i+2]-(2*4096))
            if 7*i+3 < len(code_list):
                layer_3.append(code_list[7*i+3]-(3*4096))
            if 7*i+4 < len(code_list):
                layer_2.append(code_list[7*i+4]-(4*4096))
            if 7*i+5 < len(code_list):
                layer_3.append(code_list[7*i+5]-(5*4096))
            if 7*i+6 < len(code_list):
                layer_3.append(code_list[7*i+6]-(6*4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0),
            torch.tensor(layer_2).unsqueeze(0),
            torch.tensor(layer_3).unsqueeze(0)
        ]

        # Decode on CPU; inference_mode avoids building an autograd graph.
        self.snac_model.to("cpu")
        with torch.inference_mode():
            audio_hat = self.snac_model.decode(codes)

        return audio_hat

    def audio_to_wav_bytes(self, audio_tensor, sample_rate=24000):
        """Serialize a float waveform (range [-1, 1]) to mono 16-bit WAV bytes."""
        audio_np = audio_tensor.detach().squeeze().to("cpu").numpy()

        # Clip and scale to the int16 range.
        audio_np = np.clip(audio_np, -1.0, 1.0)
        audio_int16 = (audio_np * 32767).astype(np.int16)

        # Write a WAV container in memory.
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono
            wav_file.setsampwidth(2)   # 2 bytes per sample (int16)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_int16.tobytes())

        wav_buffer.seek(0)
        return wav_buffer.getvalue()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference function for TTS.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs: Text string or list of text strings to synthesize
                - parameters (optional): Additional parameters like voice, generation settings

        Returns:
            List[Dict[str, Any]]: One dict per input with base64 WAV audio,
            or an ``error`` key on failure.
        """
        try:
            # Extract inputs and parameters.
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            if not inputs:
                return [{"error": "No input text provided"}]

            # "voice" is handled here; everything else is forwarded to generate().
            voice = parameters.get("voice", None)
            generation_params = {k: v for k, v in parameters.items() if k != "voice"}

            # Normalize to a list of preprocessed texts.
            texts = [inputs] if isinstance(inputs, str) else inputs
            processed_texts = [self.preprocess_text(text, voice) for text in texts]

            # Tokenize, generate and extract audio codes.
            input_ids, attention_mask = self.prepare_input_ids(processed_texts)
            generated_ids = self.generate_codes(input_ids, attention_mask, **generation_params)
            code_lists = self.process_generated_codes(generated_ids)

            # Decode audio for each input.
            results = []
            for i, code_list in enumerate(code_lists):
                try:
                    audio_tensor = self.redistribute_codes(code_list)
                    wav_bytes = self.audio_to_wav_bytes(audio_tensor)
                    audio_b64 = base64.b64encode(wav_bytes).decode('utf-8')

                    results.append({
                        "text": texts[i] if i < len(texts) else processed_texts[i],
                        "audio": audio_b64,
                        "content_type": "audio/wav",
                        "sample_rate": 24000
                    })

                except Exception as e:
                    # Per-item failure: report it but keep processing the batch.
                    results.append({
                        "text": texts[i] if i < len(texts) else processed_texts[i],
                        "error": f"Audio generation failed: {str(e)}"
                    })

            return results

        except Exception as e:
            # Top-level boundary: return the failure rather than crash the endpoint.
            return [{"error": f"TTS processing failed: {str(e)}"}]
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ unsloth
3
+ snac
4
+ transformers
5
+ torchaudio
6
+ uvicorn
7
+ fastapi
8
+ python-multipart
9
+ numpy
10
+ datasets
11
+ soundfile