| import os |
| import re |
| import json |
| import torch |
| import inflect |
| import random |
| import uroman as ur |
| import numpy as np |
| import torchaudio |
| from transformers import AutoTokenizer |
| from outetts.wav_tokenizer.decoder import WavTokenizer |
| from outetts.wav_tokenizer.encoder.utils import convert_audio |
|
|
class AudioTokenizer:
    """Builds text prompts for a speech LLM and decodes its generated audio
    codes back to a waveform through a WavTokenizer codec.

    Typical flow: ``create_prompt`` -> ``tokenize_prompt`` -> (model generates)
    -> ``get_codes`` -> ``get_audio``.
    """

    def __init__(self, tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Load the text tokenizer and the WavTokenizer codec.

        Args:
            tokenizer_path: HuggingFace tokenizer name or local path.
            wav_tokenizer_model_path: checkpoint path for the WavTokenizer codec.
            wav_tokenizer_config_path: config path for the WavTokenizer codec.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Template filled in create_prompt(); the speaker's audio codes are
        # appended after it as a voice prefix.
        self.text_prompt = "{bos}\n{text_start}{words}{text_end}\n{audio_start}\n"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.bos = "<|im_start|>"
        self.eos = "<|im_end|>"
        # Token length of the most recent prompt passed to tokenize_prompt();
        # get_codes() uses it to strip the prompt from generated output.
        self.input_length = 0
        self.special_tokens = {
            "audio_code": "<|{}|>",
            "text_start": "<|text_start|>",
            "text_end": "<|text_end|>",
            "audio_start": "<|audio_start|>",
            "audio_end": "<|audio_end|>",
            "time": "<|t_{:.2f}|>",
            "code_start": "<|code_start|>",
            "code_end": "<|code_end|>",
            "text_sep": "<|text_sep|>"
        }
        # inflect engine spells out digits ("42" -> "forty-two") in process_text().
        self.lec = inflect.engine()

        self.wavtokenizer = WavTokenizer.from_pretrained0802(wav_tokenizer_config_path, wav_tokenizer_model_path)
        self.wavtokenizer = self.wavtokenizer.to(self.device)
        self.BASE_DIR = os.path.dirname(__file__)
        self.DEFAULT_SPEAKERS_DIR = os.path.join(self.BASE_DIR, "default_speakers")
        # Names of the bundled speaker JSON profiles (see default_speakers/).
        self.speakers = ["idera", "emma", "onye", "jude", "osagie", "tayo",
                         "zainab", "joke", "regina", "remi", "umar", "chinenye"]

    def get_speaker_path(self, speaker_name):
        """Return the JSON profile path for a bundled default speaker."""
        return os.path.join(self.DEFAULT_SPEAKERS_DIR, f"{speaker_name}.json")

    def load_speaker(self, path: str):
        """Load a speaker profile (dict with at least "text" and "words") from JSON."""
        with open(path, "r") as f:
            return json.load(f)

    def load_default_speaker(self, name: str):
        """Load a bundled speaker profile by case-insensitive name."""
        name = name.lower().strip()
        speaker_path = self.get_speaker_path(name)
        return self.load_speaker(speaker_path)

    def process_text(self, text: str):
        """Normalize text to a list of lowercase alphabetic words.

        Numbers are spelled out, common separators become spaces, and any
        remaining non a-z character is dropped.
        """
        text = re.sub(r'\d+(\.\d+)?', lambda x: self.lec.number_to_words(x.group()), text.lower())
        text = re.sub(r'[-_/,\.\\]', ' ', text)
        text = re.sub(r'[^a-z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text.split()

    def create_audio_prompt(self, words: list) -> str:
        """Render speaker word entries (word, duration, codes) as prompt lines.

        Each entry becomes: word + <|t_D.DD|> + <|code_start|>codes<|code_end|>.
        """
        prompt = []
        for entry in words:
            word = entry["word"]
            duration = self.special_tokens["time"].format(float(entry["duration"]))
            tokens = "".join(self.special_tokens["audio_code"].format(c) for c in entry["codes"])
            prompt.append(f'{word}{duration}{self.special_tokens["code_start"]}{tokens}{self.special_tokens["code_end"]}')
        return "\n".join(prompt)

    def create_prompt(self, text, speaker_name="idera"):
        """Build the full generation prompt for `text` in `speaker_name`'s voice.

        The speaker's reference transcript is prepended to the target text,
        and the speaker's audio codes follow as a cloning prefix.
        """
        speaker = self.load_default_speaker(speaker_name)
        input_words = self.process_text(speaker["text"]) + self.process_text(text)
        inputs_words_strings = self.special_tokens['text_sep'].join(w.strip() for w in input_words)
        prompt = self.text_prompt.format(
            bos=self.bos,
            text_start=self.special_tokens['text_start'],
            words=inputs_words_strings,
            text_end=self.special_tokens['text_end'],
            audio_start=self.special_tokens['audio_start']
        )
        prompt += self.create_audio_prompt(speaker["words"])
        return prompt

    def tokenize_prompt(self, prompt):
        """Encode `prompt` to input ids and remember its length for get_codes()."""
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).to(self.device)  # FIX: the device move was duplicated on return
        self.input_length = input_ids.shape[1]
        return input_ids

    def get_audio(self, discrete_code):
        """Decode a list of WavTokenizer codes into a waveform tensor on CPU."""
        discrete_code = torch.tensor([[discrete_code]]).to(self.device)
        features = self.wavtokenizer.codes_to_features(discrete_code).to(self.device)
        bandwidth_id = torch.tensor([0]).to(self.device)
        audio_out = self.wavtokenizer.decode(features, bandwidth_id=bandwidth_id)
        return audio_out.to("cpu")

    def extract_integers(self, s):
        """Return every (possibly negative) integer wrapped as <|n|> in `s`."""
        matches = re.findall(r'\|(-?\d+)\|', s)
        return [int(match) for match in matches]

    def get_codes(self, output):
        """Strip the prompt tokens from model `output` and extract audio codes."""
        new_output = self.tokenizer.decode(output[0][self.input_length:])
        return self.extract_integers(new_output)
|
|
|
|
class AudioTokenizerForLocal(AudioTokenizer):
    """AudioTokenizer variant for Hausa, Igbo and Yoruba speech.

    Romanizes input text with uroman before normalization and inserts a
    language tag token into the generation prompt.
    """

    def __init__(self, tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Initialize the base tokenizer, then install the multilingual prompt
        template, special tokens and local speaker profiles."""
        super().__init__(tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path)
        # Prompt template gains a {lang} slot compared to the base class.
        self.text_prompt = "{bos}\n{text_start}{words}{text_end}\n{lang}\n{audio_start}\n"
        self.special_tokens = {
            "audio_code": "<|{}|>",
            "text_start": "<|text_start|>",
            "text_end": "<|text_end|>",
            "audio_start": "<|audio_start|>",
            "audio_end": "<|audio_end|>",
            "word_start": "<|word_start|>",
            "word_end": "<|word_end|>",
            "time": "<|t_{:.2f}|>",
            "code_start": "<|code_start|>",
            "code_end": "<|code_end|>",
            "text_sep": "<|text_sep|>",
            "hausa": "<|hausa|>",
            "igbo": "<|igbo|>",
            "yoruba": "<|yoruba|>",
        }
        self.uroman = ur.Uroman()
        self.DEFAULT_SPEAKERS_DIR = os.path.join(self.BASE_DIR, "default_speakers_local")
        # FIX: a missing comma after "igbo_male2" previously fused it with
        # "hausa_female1" into the single bogus entry "igbo_male2hausa_female1".
        self.speakers = [
            "hausa_male1", "hausa_male2", "yoruba_male1", "yoruba_male2",
            "igbo_male2", "hausa_female1", "hausa_female2",
            "igbo_female1", "igbo_female2", "yoruba_female1", "yoruba_female2",
        ]

    def process_text(self, text: str):
        """Romanize with uroman, then normalize to lowercase alphabetic words
        (same pipeline as the base class after romanization)."""
        text = self.uroman.romanize_string(text)
        text = re.sub(r'\d+(\.\d+)?', lambda x: self.lec.number_to_words(x.group()), text.lower())
        text = re.sub(r'[-_/,\.\\]', ' ', text)
        text = re.sub(r'[^a-z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text.split()

    def create_prompt(self, text, lang, speaker_name=None):
        """Build a generation prompt for `text` in `lang`.

        A speaker is chosen at random for the language when none is given.
        Raises AssertionError for an unsupported language.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is to
        # preserve the exception type callers may rely on.
        assert lang in ["hausa","igbo","yoruba"], f"Invalid language: {lang}, language must be one of ['hausa','igbo','yoruba']"

        if speaker_name is None:
            if lang == "hausa":
                speaker_name = random.choice(["hausa_male1", "hausa_male2", "hausa_female1", "hausa_female2"])
            elif lang == "igbo":
                speaker_name = random.choice(["igbo_female1", "igbo_female2", "igbo_male2"])
            else:
                speaker_name = random.choice(["yoruba_male2", "yoruba_female1", "yoruba_female2"])
        speaker = self.load_default_speaker(speaker_name)
        input_words = self.process_text(speaker["text"]) + self.process_text(text)
        inputs_words_strings = self.special_tokens['text_sep'].join(w.strip() for w in input_words)
        prompt = self.text_prompt.format(
            bos=self.bos,
            text_start=self.special_tokens['text_start'],
            words=inputs_words_strings,
            text_end=self.special_tokens['text_end'],
            lang=self.special_tokens[lang],
            audio_start=self.special_tokens['audio_start']
        )
        prompt += self.create_audio_prompt(speaker["words"])
        return prompt
|
|
|
|
class AudioTokenizerV2(AudioTokenizer):
    """Multilingual tokenizer (Hausa, Igbo, Yoruba, English) that can also
    build ASR prompts directly from an audio file."""

    def __init__(self, tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Initialize the base tokenizer, then install the multilingual/ASR
        prompt templates, special tokens and both speaker directories."""
        super().__init__(tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path)
        self.text_prompt = "{bos}\n{text_start}{words}{text_end}\n{lang}\n{audio_start}\n"
        # Template for speech-to-text: raw codec codes followed by the ASR tag.
        self.asr_prompt = "{bos}\n{code_start}{codes}{code_end}\n{asr}\n"
        self.special_tokens = {
            "audio_code": "<|{}|>",
            "text_start": "<|text_start|>",
            "text_end": "<|text_end|>",
            "audio_start": "<|audio_start|>",
            "audio_end": "<|audio_end|>",
            "word_start": "<|word_start|>",
            "word_end": "<|word_end|>",
            "time": "<|t_{:.2f}|>",
            "code_start": "<|code_start|>",
            "code_end": "<|code_end|>",
            "text_sep": "<|text_sep|>",
            "hausa": "<|hausa|>",
            "igbo": "<|igbo|>",
            "yoruba": "<|yoruba|>",
            "english": "<|english|>",
            "asr": "<|asr|>"
        }
        self.uroman = ur.Uroman()
        self.DEFAULT_SPEAKERS_DIR_LOCAL = os.path.join(self.BASE_DIR, "default_speakers_local")
        self.DEFAULT_SPEAKERS_ENG = os.path.join(self.BASE_DIR, "default_speakers")
        # FIX: a missing comma after "igbo_male2" previously fused it with
        # "hausa_female1" into the single bogus entry "igbo_male2hausa_female1".
        self.speakers_local = [
            "hausa_male1", "hausa_male2", "yoruba_male1", "yoruba_male2",
            "igbo_male2", "hausa_female1", "hausa_female2",
            "igbo_female1", "igbo_female2", "yoruba_female1", "yoruba_female2",
        ]
        self.speakers_eng = ["idera", "emma", "onye", "jude", "osagie", "tayo", "zainab",
                             "joke", "regina", "remi", "umar", "chinenye", "saheed"]
        # Ordered (old, new) substitutions applied by replace_tokens(); order
        # matters, since later pairs emit tokens already handled by earlier
        # ones. Presumably remaps unsupported code tokens — TODO confirm.
        self.changed_tokens = [('<|1836|>', '<|453|><|453|>'),
                               ('<|1837|>', '<|1836|><|1836|>'),
                               ('<|1838|>', '<|1837|><|1837|>'),
                               ('<|1840|>', '<|244|><|167|>'),
                               ('<|1841|>', '<|235|><|219|>'),
                               ('<|1844|>', '<|453|><|244|>'),
                               ('<|1845|>', '<|1838|><|1838|>')]

    def process_text(self, text: str):
        """Romanize with uroman, then normalize to lowercase alphabetic words."""
        text = self.uroman.romanize_string(text)
        text = re.sub(r'\d+(\.\d+)?', lambda x: self.lec.number_to_words(x.group()), text.lower())
        text = re.sub(r'[-_/,\.\\]', ' ', text)
        text = re.sub(r'[^a-z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text.split()

    def get_speaker_path(self, speaker_name, dir):
        """Return the JSON profile path for `speaker_name` inside directory `dir`.

        Overrides the base class to make the directory explicit (local vs
        English speakers live in different folders).
        """
        return os.path.join(dir, f"{speaker_name}.json")

    def load_speaker(self, path: str):
        """Load a speaker profile (dict with "text" and "words") from JSON."""
        with open(path, "r") as f:
            return json.load(f)

    def load_default_speaker(self, name: str, dir: str):
        """Load a bundled speaker profile by case-insensitive name from `dir`."""
        name = name.lower().strip()
        speaker_path = self.get_speaker_path(name, dir)
        return self.load_speaker(speaker_path)

    def create_prompt(self, text, lang, speaker_name=None):
        """Build a generation prompt for `text` in `lang`.

        A speaker is chosen at random for the language when none is given;
        named English speakers resolve to the English profile directory.
        Raises AssertionError for an unsupported language.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is to
        # preserve the exception type callers may rely on.
        assert lang in ["hausa","igbo","yoruba","english"], f"Invalid language: {lang}, language must be one of ['hausa','igbo','yoruba','english']"

        speakers_dir = self.DEFAULT_SPEAKERS_DIR_LOCAL  # renamed: `dir` shadowed the builtin
        named_speakers = ["idera", "emma", "onye", "jude", "osagie", "tayo", "zainab",
                          "joke", "regina", "remi", "umar", "chinenye"]
        if speaker_name is None:
            if lang == "hausa":
                speaker_name = random.choice(["hausa_male1", "hausa_male2", "hausa_female1", "hausa_female2"])
            elif lang == "igbo":
                speaker_name = random.choice(["igbo_female1", "igbo_female2", "igbo_male2"])
            elif lang == "yoruba":
                speaker_name = random.choice(["yoruba_male2", "yoruba_female1", "yoruba_female2"])
            else:
                speaker_name = random.choice(self.speakers_eng)

        if speaker_name.lower() in named_speakers or lang == "english":
            speakers_dir = self.DEFAULT_SPEAKERS_ENG
        speaker = self.load_default_speaker(speaker_name, speakers_dir)
        input_words = self.process_text(speaker["text"]) + self.process_text(text)
        inputs_words_strings = self.special_tokens['text_sep'].join(w.strip() for w in input_words)
        prompt = self.text_prompt.format(
            bos=self.bos,
            text_start=self.special_tokens['text_start'],
            words=inputs_words_strings,
            text_end=self.special_tokens['text_end'],
            lang=self.special_tokens[lang],
            audio_start=self.special_tokens['audio_start']
        )
        prompt += self.create_audio_prompt(speaker["words"])
        return prompt

    def replace_tokens(self, text):
        """Apply the ordered `changed_tokens` substitutions to decoded text.

        FIX: the original definition omitted `self` while reading
        self.changed_tokens, so every call raised an error.
        """
        for old, new in self.changed_tokens:
            text = text.replace(old, new)
        return text

    def resample(self, audio: torch.Tensor, sr: int, target_sr: int):
        """Resample a waveform tensor from `sr` to `target_sr`, mono, on device.

        FIX: parameter was annotated np.ndarray, but the body uses tensor-only
        ops (.to(dtype=...), .unsqueeze) on torchaudio.load output.
        """
        audio = audio.to(dtype=torch.float32)
        audio = audio.unsqueeze(0)  # add channel dim expected by convert_audio
        resampled = convert_audio(audio, sr, target_sr, 1)
        return resampled.to(self.device)

    def quantize_wavtokenizer(self, path):
        """Encode the audio file at `path` into a "<|c1|><|c2|>..." code string."""
        audio_data, sample_rate = torchaudio.load(path)
        audio_data = audio_data.squeeze()
        audio = self.resample(audio_data, sample_rate, 24000).to(self.device)
        if audio.ndim == 3:
            audio = audio.squeeze(1)
        bandwidth_id = torch.tensor([0]).to(self.device)
        _, codes = self.wavtokenizer.encode_infer(audio, bandwidth_id=bandwidth_id)
        codes = codes.squeeze(1).to(self.device)
        return "".join(f"<|{code}|>" for code in codes[0].tolist())

    def create_asr_prompt(self, audio_path):
        """Build an ASR (speech-to-text) prompt from the audio file at `audio_path`."""
        codes = self.quantize_wavtokenizer(audio_path)
        prompt = self.asr_prompt.format(
            bos=self.bos,
            code_start=self.special_tokens['code_start'],
            codes=codes,
            code_end=self.special_tokens['code_end'],
            asr=self.special_tokens["asr"],
        )
        return prompt

    def get_asr_results(self, output):
        """Extract the transcript from ASR model output.

        Takes the text between <|text_start|> and <|text_end|>, pulls each
        line's word from between <|word_start|>/<|word_end|>, and joins with
        spaces.
        """
        decoded = self.tokenizer.decode(output[0])
        body = decoded.split("<|text_start|>")[-1].split("<|text_end|>")[0]
        res = ""
        for text in body.split("\n"):
            res += text.split("<|word_start|>")[-1].split("<|word_end|>")[0]
            res += " "
        return res.strip()
|
|