# yarngpt-custom / audiotokenizer.py
# Uploaded by laztopaz with huggingface_hub (commit 9a5de2b, verified).
import os
import re
import json
import torch
import inflect
import random
import uroman as ur
import numpy as np
import torchaudio
from transformers import AutoTokenizer
from outetts.wav_tokenizer.decoder import WavTokenizer
from outetts.wav_tokenizer.encoder.utils import convert_audio
class AudioTokenizer:
    """Build text/audio prompts and turn generated audio codes back into audio.

    The class formats a reference speaker's transcript and audio codes plus
    the text to synthesise into a single prompt string, tokenizes that
    prompt, and decodes the integer audio codes found in generated output
    through a WavTokenizer model.
    """

    def __init__(self, tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Load the text tokenizer and the WavTokenizer model.

        Args:
            tokenizer_path: Passed to ``AutoTokenizer.from_pretrained``.
            wav_tokenizer_model_path: WavTokenizer checkpoint path.
            wav_tokenizer_config_path: WavTokenizer config path.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.text_prompt = "{bos}\n{text_start}{words}{text_end}\n{audio_start}\n"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.bos = "<|im_start|>"
        self.eos = "<|im_end|>"
        # Number of prompt tokens from the latest tokenize_prompt() call;
        # get_codes() uses it to skip the prompt part of the model output.
        self.input_length = 0
        self.special_tokens = {
            "audio_code": "<|{}|>",
            "text_start": "<|text_start|>",
            "text_end": "<|text_end|>",
            "audio_start": "<|audio_start|>",
            "audio_end": "<|audio_end|>",
            "time": "<|t_{:.2f}|>",
            "code_start": "<|code_start|>",
            "code_end": "<|code_end|>",
            "text_sep": "<|text_sep|>",
        }
        # inflect engine, used to spell out numbers in process_text().
        self.lec = inflect.engine()
        # NOTE: from_pretrained0802 takes the config path first, then the checkpoint.
        self.wavtokenizer = WavTokenizer.from_pretrained0802(
            wav_tokenizer_config_path, wav_tokenizer_model_path
        )
        self.wavtokenizer = self.wavtokenizer.to(self.device)
        self.BASE_DIR = os.path.dirname(__file__)
        self.DEFAULT_SPEAKERS_DIR = os.path.join(self.BASE_DIR, "default_speakers")
        self.speakers = [
            "idera", "emma", "onye", "jude", "osagie", "tayo",
            "zainab", "joke", "regina", "remi", "umar", "chinenye",
        ]

    def get_speaker_path(self, speaker_name):
        """Return the path of ``<speaker_name>.json`` inside the default speakers directory."""
        return os.path.join(self.DEFAULT_SPEAKERS_DIR, f"{speaker_name}.json")

    def load_speaker(self, path: str):
        """Read a speaker JSON file and return the parsed content.

        The file is opened as UTF-8 explicitly so the result does not depend
        on the platform's default encoding.
        """
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_default_speaker(self, name: str):
        """Load a bundled speaker by name (case-insensitive, surrounding whitespace ignored)."""
        name = name.lower().strip()
        return self.load_speaker(self.get_speaker_path(name))

    def process_text(self, text: str) -> list:
        """Normalise *text* and split it into a list of lowercase a-z words.

        Steps: lowercase, spell out integers/decimals with inflect, turn
        ``- _ / , . \\`` into spaces, drop every character that is not a-z or
        whitespace, collapse whitespace, split.
        """
        text = re.sub(
            r'\d+(\.\d+)?',
            lambda m: self.lec.number_to_words(m.group()),
            text.lower(),
        )
        text = re.sub(r'[-_/,\.\\]', ' ', text)
        text = re.sub(r'[^a-z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text.split()

    def create_audio_prompt(self, words: list) -> str:
        """Render speaker word entries as prompt lines.

        Each entry must provide ``"word"``, ``"duration"`` (convertible to
        float) and ``"codes"`` (iterable of audio codes). One line is produced
        per entry: word, time token, then the codes wrapped in the
        code_start/code_end tokens. Lines are joined with newlines.
        """
        code_fmt = self.special_tokens["audio_code"]
        time_fmt = self.special_tokens["time"]
        code_start = self.special_tokens["code_start"]
        code_end = self.special_tokens["code_end"]

        lines = []
        for entry in words:
            duration = time_fmt.format(float(entry["duration"]))
            codes = "".join(code_fmt.format(c) for c in entry["codes"])
            lines.append(f'{entry["word"]}{duration}{code_start}{codes}{code_end}')
        return "\n".join(lines)

    def create_prompt(self, text, speaker_name="idera"):
        """Build the full generation prompt for *text* using a bundled speaker.

        The speaker's own transcript is placed before *text* in the word
        list, and the speaker's audio-code lines are appended after the
        audio_start token.
        """
        speaker = self.load_default_speaker(speaker_name)
        input_words = self.process_text(speaker["text"]) + self.process_text(text)
        # Words come from str.split(), so they carry no surrounding whitespace.
        words_string = self.special_tokens["text_sep"].join(input_words)
        prompt = self.text_prompt.format(
            bos=self.bos,
            text_start=self.special_tokens["text_start"],
            words=words_string,
            text_end=self.special_tokens["text_end"],
            audio_start=self.special_tokens["audio_start"],
        )
        return prompt + self.create_audio_prompt(speaker["words"])

    def tokenize_prompt(self, prompt):
        """Encode *prompt* to a tensor of token ids on ``self.device``.

        Also records the prompt length (second dimension of the result) in
        ``self.input_length`` for a later get_codes() call.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        self.input_length = input_ids.shape[1]
        return input_ids

    def get_audio(self, discrete_code):
        """Decode audio codes with WavTokenizer and return the result on the CPU.

        *discrete_code* is wrapped in two extra list levels before being
        turned into a tensor (presumably a flat list of ints -- confirm
        against callers).
        """
        codes = torch.tensor([[discrete_code]]).to(self.device)
        features = self.wavtokenizer.codes_to_features(codes).to(self.device)
        bandwidth_id = torch.tensor([0]).to(self.device)
        audio_out = self.wavtokenizer.decode(features, bandwidth_id=bandwidth_id)
        return audio_out.to("cpu")

    def extract_integers(self, s):
        """Return every integer that appears as ``|<int>|`` in *s*, in order."""
        return [int(match) for match in re.findall(r'\|(-?\d+)\|', s)]

    def get_codes(self, output):
        """Extract audio codes from generated token ids.

        Only ``output[0]`` is read; its first ``self.input_length`` tokens
        (the prompt) are skipped before decoding to text.
        """
        generated_text = self.tokenizer.decode(output[0][self.input_length:])
        return self.extract_integers(generated_text)
class AudioTokenizerForLocal(AudioTokenizer):
    """AudioTokenizer variant for Hausa, Igbo and Yoruba prompts.

    Differences from the base class: the prompt template carries a language
    token, input text is romanized with uroman before normalisation, and
    speakers are loaded from the ``default_speakers_local`` directory.
    """

    def __init__(self, tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Initialise the base tokenizer, then install the language-aware settings."""
        super().__init__(tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path)
        self.text_prompt = "{bos}\n{text_start}{words}{text_end}\n{lang}\n{audio_start}\n"
        self.special_tokens = {
            "audio_code": "<|{}|>",
            "text_start": "<|text_start|>",
            "text_end": "<|text_end|>",
            "audio_start": "<|audio_start|>",
            "audio_end": "<|audio_end|>",
            "word_start": "<|word_start|>",
            "word_end": "<|word_end|>",
            "time": "<|t_{:.2f}|>",
            "code_start": "<|code_start|>",
            "code_end": "<|code_end|>",
            "text_sep": "<|text_sep|>",
            "hausa": "<|hausa|>",
            "igbo": "<|igbo|>",
            "yoruba": "<|yoruba|>",
        }
        self.uroman = ur.Uroman()
        self.DEFAULT_SPEAKERS_DIR = os.path.join(self.BASE_DIR, "default_speakers_local")
        # Every entry needs its own trailing comma: adjacent string literals
        # are silently concatenated, which previously merged "igbo_male2" and
        # "hausa_female1" into one bogus name.
        self.speakers = [
            "hausa_male1", "hausa_male2", "yoruba_male1", "yoruba_male2",
            "igbo_male2",  # "igbo_male1" is currently disabled
            "hausa_female1", "hausa_female2", "igbo_female1", "igbo_female2",
            "yoruba_female1", "yoruba_female2",
        ]

    def process_text(self, text: str):
        """Romanize *text* with uroman, then apply the base-class normalisation."""
        return super().process_text(self.uroman.romanize_string(text))

    def create_prompt(self, text, lang, speaker_name=None):
        """Build the generation prompt for *text* in language *lang*.

        Args:
            text: Text to synthesise.
            lang: One of ``"hausa"``, ``"igbo"``, ``"yoruba"``.
            speaker_name: Bundled speaker to use; when ``None`` a speaker is
                picked at random from the pool for *lang*.

        Raises:
            AssertionError: If *lang* is not a supported language.
        """
        # NOTE(review): kept as an assert so callers see the same exception
        # type as before; be aware it is skipped when Python runs with -O.
        assert lang in ["hausa","igbo","yoruba"], f"Invalid language: {lang}, language must be one of ['hausa','igbo','yoruba']"

        if speaker_name is None:
            random_pools = {
                "hausa": ["hausa_male1", "hausa_male2", "hausa_female1", "hausa_female2"],
                "igbo": ["igbo_female1", "igbo_female2", "igbo_male2"],  # "igbo_male1" disabled
                "yoruba": ["yoruba_male2", "yoruba_female1", "yoruba_female2"],
            }
            speaker_name = random.choice(random_pools[lang])

        speaker = self.load_default_speaker(speaker_name)
        # The speaker's transcript comes first, followed by the new text.
        input_words = self.process_text(speaker["text"]) + self.process_text(text)
        words_string = self.special_tokens["text_sep"].join(input_words)
        prompt = self.text_prompt.format(
            bos=self.bos,
            text_start=self.special_tokens["text_start"],
            words=words_string,
            text_end=self.special_tokens["text_end"],
            lang=self.special_tokens[lang],
            audio_start=self.special_tokens["audio_start"],
        )
        return prompt + self.create_audio_prompt(speaker["words"])
class AudioTokenizerV2(AudioTokenizer):
    """AudioTokenizer variant covering English, Hausa, Igbo and Yoruba plus ASR prompts.

    Adds to the base class: a language token in the prompt template, uroman
    romanization of input text, two speaker directories (English and local
    languages), and helpers that encode an audio file into code tokens for
    an ASR prompt and parse the transcription out of the model output.
    """

    def __init__(self, tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Initialise the base tokenizer, then install the V2 templates and speaker tables."""
        super().__init__(tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path)
        self.text_prompt = "{bos}\n{text_start}{words}{text_end}\n{lang}\n{audio_start}\n"
        self.asr_prompt = "{bos}\n{code_start}{codes}{code_end}\n{asr}\n"
        self.special_tokens = {
            "audio_code": "<|{}|>",
            "text_start": "<|text_start|>",
            "text_end": "<|text_end|>",
            "audio_start": "<|audio_start|>",
            "audio_end": "<|audio_end|>",
            "word_start": "<|word_start|>",
            "word_end": "<|word_end|>",
            "time": "<|t_{:.2f}|>",
            "code_start": "<|code_start|>",
            "code_end": "<|code_end|>",
            "text_sep": "<|text_sep|>",
            "hausa": "<|hausa|>",
            "igbo": "<|igbo|>",
            "yoruba": "<|yoruba|>",
            "english": "<|english|>",
            "asr": "<|asr|>",
        }
        self.uroman = ur.Uroman()
        self.DEFAULT_SPEAKERS_DIR_LOCAL = os.path.join(self.BASE_DIR, "default_speakers_local")
        self.DEFAULT_SPEAKERS_ENG = os.path.join(self.BASE_DIR, "default_speakers")
        # Every entry needs its own trailing comma: adjacent string literals
        # are silently concatenated, which previously merged "igbo_male2" and
        # "hausa_female1" into one bogus name.
        self.speakers_local = [
            "hausa_male1", "hausa_male2", "yoruba_male1", "yoruba_male2",
            "igbo_male2",  # "igbo_male1" is currently disabled
            "hausa_female1", "hausa_female2", "igbo_female1", "igbo_female2",
            "yoruba_female1", "yoruba_female2",
        ]
        self.speakers_eng = [
            "idera", "emma", "onye", "jude", "osagie", "tayo", "zainab",
            "joke", "regina", "remi", "umar", "chinenye", "saheed",
        ]
        # (token, replacement) pairs applied in this order by replace_tokens().
        # NOTE(review): presumably remaps codes that collide with other
        # vocabulary entries -- confirm with the model's tokenizer.
        self.changed_tokens = [
            ('<|1836|>', '<|453|><|453|>'),
            ('<|1837|>', '<|1836|><|1836|>'),
            ('<|1838|>', '<|1837|><|1837|>'),
            ('<|1840|>', '<|244|><|167|>'),
            ('<|1841|>', '<|235|><|219|>'),
            ('<|1844|>', '<|453|><|244|>'),
            ('<|1845|>', '<|1838|><|1838|>'),
        ]

    def process_text(self, text: str):
        """Romanize *text* with uroman, then apply the base-class normalisation."""
        return super().process_text(self.uroman.romanize_string(text))

    def get_speaker_path(self, speaker_name, dir):
        """Return the path of ``<speaker_name>.json`` inside directory *dir*.

        The parameter name ``dir`` shadows the builtin; it is kept so that
        existing keyword callers continue to work.
        """
        return os.path.join(dir, f"{speaker_name}.json")

    def load_speaker(self, path: str):
        """Read a speaker JSON file (opened explicitly as UTF-8) and return the parsed content."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_default_speaker(self, name: str, dir: str):
        """Load speaker *name* (case-insensitive, whitespace-trimmed) from directory *dir*."""
        name = name.lower().strip()
        return self.load_speaker(self.get_speaker_path(name, dir))

    def create_prompt(self, text, lang, speaker_name=None):
        """Build the generation prompt for *text* in language *lang*.

        Args:
            text: Text to synthesise.
            lang: One of ``"hausa"``, ``"igbo"``, ``"yoruba"``, ``"english"``.
            speaker_name: Bundled speaker to use; when ``None`` a speaker is
                picked at random from the pool for *lang*.

        Raises:
            AssertionError: If *lang* is not a supported language.
        """
        # NOTE(review): kept as an assert so callers see the same exception
        # type as before; be aware it is skipped when Python runs with -O.
        assert lang in ["hausa","igbo","yoruba","english"], f"Invalid language: {lang}, language must be one of ['hausa','igbo','yoruba','english']"

        if speaker_name is None:
            random_pools = {
                "hausa": ["hausa_male1", "hausa_male2", "hausa_female1", "hausa_female2"],
                "igbo": ["igbo_female1", "igbo_female2", "igbo_male2"],  # "igbo_male1" disabled
                "yoruba": ["yoruba_male2", "yoruba_female1", "yoruba_female2"],
                "english": self.speakers_eng,
            }
            speaker_name = random.choice(random_pools[lang])

        # English-named speakers live in the English directory whatever the
        # requested language. The check uses self.speakers_eng (not a separate
        # hard-coded list) so every English speaker, including "saheed",
        # resolves to the right directory, and normalises the name the same
        # way load_default_speaker() does.
        speakers_dir = self.DEFAULT_SPEAKERS_DIR_LOCAL
        if lang == "english" or speaker_name.lower().strip() in self.speakers_eng:
            speakers_dir = self.DEFAULT_SPEAKERS_ENG

        speaker = self.load_default_speaker(speaker_name, speakers_dir)
        # The speaker's transcript comes first, followed by the new text.
        input_words = self.process_text(speaker["text"]) + self.process_text(text)
        words_string = self.special_tokens["text_sep"].join(input_words)
        prompt = self.text_prompt.format(
            bos=self.bos,
            text_start=self.special_tokens["text_start"],
            words=words_string,
            text_end=self.special_tokens["text_end"],
            lang=self.special_tokens[lang],
            audio_start=self.special_tokens["audio_start"],
        )
        return prompt + self.create_audio_prompt(speaker["words"])

    def replace_tokens(self, text):
        """Apply every ``self.changed_tokens`` substitution to *text*, in list order.

        Order matters: some replacements produce tokens that appear as the
        search key of an earlier pair, and those must not be rewritten again.
        """
        for old, new in self.changed_tokens:
            text = text.replace(old, new)
        return text

    def resample(self, audio: torch.Tensor, sr: int, target_sr: int):
        """Convert *audio* to float32 mono at *target_sr* and move it to ``self.device``.

        A leading dimension is added before the conversion.
        """
        audio = audio.to(dtype=torch.float32).unsqueeze(0)
        # 1 as last arg corresponds to mono audio
        resampled = convert_audio(audio, sr, target_sr, 1)
        return resampled.to(self.device)

    def quantize_wavtokenizer(self, path):
        """Encode the audio file at *path* into a string of ``<|code|>`` tokens.

        The audio is loaded with torchaudio, resampled to 24000 Hz mono and
        encoded with WavTokenizer; the codes of the first item are rendered
        with the ``audio_code`` token format and concatenated.
        """
        audio_data, sample_rate = torchaudio.load(path)
        audio_data = audio_data.squeeze()
        audio = self.resample(audio_data, sample_rate, 24000).to(self.device)
        if audio.ndim == 3:
            audio = audio.squeeze(1)
        bandwidth_id = torch.tensor([0]).to(self.device)
        _, codes = self.wavtokenizer.encode_infer(audio, bandwidth_id=bandwidth_id)
        codes = codes.squeeze(1).to(self.device)
        code_fmt = self.special_tokens["audio_code"]
        return "".join(code_fmt.format(code) for code in codes[0].tolist())

    def create_asr_prompt(self, audio_path):
        """Build an ASR prompt from the audio file at *audio_path*."""
        codes = self.quantize_wavtokenizer(audio_path)
        return self.asr_prompt.format(
            bos=self.bos,
            code_start=self.special_tokens["code_start"],
            codes=codes,
            code_end=self.special_tokens["code_end"],
            asr=self.special_tokens["asr"],
        )

    def get_asr_results(self, output):
        """Extract the transcription from generated token ids.

        Decodes ``output[0]``, takes the text between the last
        ``<|text_start|>`` and the following ``<|text_end|>``, and for each
        line keeps what lies between ``<|word_start|>`` and ``<|word_end|>``.
        The pieces are joined with single spaces.
        """
        decoded = self.tokenizer.decode(output[0])
        transcript = decoded.split("<|text_start|>")[-1].split("<|text_end|>")[0]
        words = [
            line.split("<|word_start|>")[-1].split("<|word_end|>")[0]
            for line in transcript.split("\n")
        ]
        return " ".join(words).strip()