# KaniTTS / util.py
# Revision 5261d4e by Den Pavloff (commit message: "arr").
import torch
import librosa
import requests
from nemo.collections.tts.models import AudioCodecModel
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
@dataclass
class Config:
    """Default settings for the KaniTTS language model, audio codec and sampling.

    The field order defines the positional constructor signature, so it is kept
    stable.
    """

    # Hugging Face repository ids for the TTS language model and the NeMo codec.
    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    # Forwarded as `device_map` when the language model is loaded.
    device_map: str = "auto"

    # Base id for the special tokens: NemoAudioPlayer derives the speech, turn,
    # pad and audio token ids as offsets from this value.
    tokeniser_length: int = 64400
    # Token ids that delimit the text segment of the sequence.
    start_of_text: int = 1
    end_of_text: int = 2

    # Generation / sampling parameters passed to `generate`.
    max_new_tokens: int = 2000
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1
class NemoAudioPlayer:
    """Turn KaniTTS language-model output tokens into audio with a NeMo codec.

    Speech in the model output is framed by `start_of_speech` /
    `end_of_speech`. Between them the codec codes are laid out frame by frame,
    `num_codebooks` tokens per frame. Codebook `i` is shifted by
    `audio_tokens_start + i * codebook_size`, and `get_nano_codes` removes
    those shifts before decoding.
    """

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        """Load the NeMo codec and, optionally, a text tokenizer.

        Args:
            config: settings object providing `audiocodec_name`,
                `tokeniser_length`, `start_of_text` and `end_of_text`.
            text_tokenizer_name: if given, a tokenizer is loaded so that
                `get_waveform` can also return the decoded text.
        """
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        # Load the NeMo codec model in eval (inference) mode.
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        self.text_tokenizer_name = text_tokenizer_name
        # Only used by get_text(); stays None when no tokenizer name is supplied.
        self.tokenizer = None
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration: special ids are fixed offsets from tokeniser_length.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        # Id of the first audio-code token (codebook 0, code 0).
        self.audio_tokens_start = self.tokeniser_length + 10
        # Entries per codebook, and codebooks interleaved per audio frame.
        self.codebook_size = 4032
        self.num_codebooks = 4

    @staticmethod
    def _first_index(out_ids, token_id, error_message):
        """Return the position of the first `token_id` in the 1-D `out_ids`.

        Raises:
            ValueError: with `error_message` if `token_id` does not occur.
        """
        positions = (out_ids == token_id).nonzero(as_tuple=True)[0]
        # `.item()` on an empty result raises RuntimeError rather than
        # IndexError, so test for absence explicitly.
        if positions.numel() == 0:
            raise ValueError(error_message)
        return positions[0].item()

    def output_validation(self, out_ids):
        """Raise ValueError unless both speech marker tokens occur in `out_ids`."""
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids):
        """Extract the codec codes between the speech markers.

        Args:
            out_ids: 1-D integer tensor of model output token ids.

        Returns:
            `(audio_codes, len_)`. `audio_codes` has shape
            `(1, num_codebooks, n_frames)` and holds per-codebook code indices.
            `len_` is a 1-element tensor containing `n_frames`.

        Raises:
            ValueError: if a marker is missing, the markers are out of order,
                the code count is not a multiple of `num_codebooks`, or a
                code falls outside its codebook's range.
        """
        missing = 'Speech start/end tokens not found!'
        start_a_idx = self._first_index(out_ids, self.start_of_speech, missing)
        end_a_idx = self._first_index(out_ids, self.end_of_speech, missing)
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if len(audio_codes) % self.num_codebooks:
            raise ValueError(
                f'Audio codes length must be multiple of {self.num_codebooks}!'
            )
        # One row per frame, one column per codebook.
        audio_codes = audio_codes.reshape(-1, self.num_codebooks)

        # Undo the per-codebook offset and the shared audio-token base.
        offsets = (
            torch.arange(self.num_codebooks, device=audio_codes.device)
            * self.codebook_size
        )
        audio_codes = audio_codes - offsets - self.audio_tokens_start
        # Codes >= codebook_size belong to another codebook's id range and
        # would index past the codec's embedding table.
        if ((audio_codes < 0) | (audio_codes >= self.codebook_size)).any():
            raise ValueError('Invalid audio tokens detected!')

        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids):
        """Decode the text segment (from start_of_text to end_of_text, inclusive).

        Raises:
            ValueError: if either text marker is missing.
            AttributeError: if no text tokenizer was configured.
        """
        if self.tokenizer is None:
            raise AttributeError(
                'get_text() needs a text tokenizer; pass text_tokenizer_name '
                'to NemoAudioPlayer'
            )
        missing = 'Text start/end tokens not found!'
        start_t_idx = self._first_index(out_ids, self.start_of_text, missing)
        end_t_idx = self._first_index(out_ids, self.end_of_text, missing)
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        return self.tokenizer.decode(txt_tokens, skip_special_tokens=True)

    def get_waveform(self, out_ids):
        """Convert model output ids into `(waveform_ndarray, text_or_None)`.

        `text` is only decoded when a text tokenizer name was given at
        construction time.
        """
        out_ids = out_ids.flatten()
        self.output_validation(out_ids)

        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
        output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        text = self.get_text(out_ids) if self.text_tokenizer_name else None
        return output_audio, text
class KaniModel:
    """Prompt the KaniTTS causal LM and turn its speech tokens into audio."""

    def __init__(self, config, player: "NemoAudioPlayer", token: str) -> None:
        """Load the language model and its tokenizer.

        Args:
            config: settings object providing `model_name`, `device_map` and
                the generation parameters read in `model_request`.
            player: supplies the special token ids and decodes the output audio.
            token: Hugging Face access token forwarded to `from_pretrained`.
        """
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")
        # Load model with proper configuration
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            # Runs Python code shipped with the checkpoint; only use trusted repos.
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    @staticmethod
    def _pad_token_id(tokenizer):
        """Return the tokenizer's pad id, falling back to its EOS id when unset.

        The comparison is against None because 0 is a valid, and common, pad id.
        """
        pad_id = tokenizer.pad_token_id
        return pad_id if pad_id is not None else tokenizer.eos_token_id

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the prompt `[START_OF_HUMAN] + text ids + [END_OF_TEXT, END_OF_HUMAN]`.

        Returns:
            `(input_ids, attention_mask)`, both int64 tensors of shape `(1, L)`.
            The mask is all ones.
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human
        # Tokenize input text
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids
        # Wrap the text in the human-turn special tokens.
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Sample a continuation of the prompt, stopping at end_of_speech.

        Returns the full generated id tensor (prompt included), on CPU.
        """
        # Use the device the model was actually placed on by `device_map`,
        # rather than assuming 'cuda' whenever a GPU is visible.
        model_device = self.model.device
        input_ids = input_ids.to(model_device)
        attention_mask = attention_mask.to(model_device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=self._pad_token_id(self.tokenizer),
            )
        return generated_ids.to('cpu')

    def run_model(self, text: str):
        """Run the full pipeline: text -> prompt ids -> generation -> audio.

        Returns `(audio, text)`, where `text` is the input prompt unchanged.
        """
        input_ids, attention_mask = self.get_input_ids(text)
        model_output = self.model_request(input_ids, attention_mask)
        audio, _ = self.player.get_waveform(model_output)
        return audio, text
class Demo:
    """Fetch (once) and load the KaniTTS example recordings for each sample sentence."""

    def __init__(self):
        """Create the local cache directory and list the example sentences and URLs."""
        self.audio_dir = './audio_examples'
        os.makedirs(self.audio_dir, exist_ok=True)
        # urls[i] is the recording paired with sentences[i].
        self.sentences = [
            "You make my days brighter, and my wildest dreams feel like reality. How do you do that?",
            "Anyway, um, so, um, tell me, tell me all about her. I mean, what's she like? Is she really, you know, pretty?",
            "Great, and just a couple quick questions so we can match you with the right buyer. Is your home address still 330 East Charleston Road?",
            "No, that does not make you a failure. No, sweetie, no. It just, uh, it just means that you're having a tough time...",
            "Oh, yeah. I mean did you want to get a quick snack together or maybe something before you go?",
            "I-- Oh, I am such an idiot sometimes. I'm so sorry. Um, I-I don't know where my head's at.",
            "Got it. $300,000. I can definitely help you get a very good price for your property by selecting a realtor.",
            "Holy fu- Oh my God! Don't you understand how dangerous it is, huh?"
        ]
        self.urls = [
            'https://www.nineninesix.ai/examples/kani/1.wav',
            'https://www.nineninesix.ai/examples/kani/2.wav',
            'https://www.nineninesix.ai/examples/kani/5.wav',
            'https://www.nineninesix.ai/examples/kani/6.wav',
            'https://www.nineninesix.ai/examples/kani/3.wav',
            'https://www.nineninesix.ai/examples/kani/7.wav',
            'https://www.nineninesix.ai/examples/kani/4.wav',
            'https://www.nineninesix.ai/examples/kani/8.wav'
        ]

    def download_audio(self, url: str, filename: str, timeout: float = 30.0):
        """Return the local path for `filename`, downloading `url` there if it is absent.

        The payload is written to a `.part` file and then renamed into place.
        An interrupted download therefore never leaves a truncated file that a
        later call would treat as already cached.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the server does not respond within `timeout` seconds.
        """
        filepath = os.path.join(self.audio_dir, filename)
        if os.path.exists(filepath):
            return filepath
        # Without a timeout, requests can block forever on a stalled connection.
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()
        tmp_path = filepath + '.part'
        with open(tmp_path, 'wb') as f:
            f.write(r.content)
        os.replace(tmp_path, filepath)
        return filepath

    def get_audio(self, filepath: str, sr: int = 22050):
        """Load `filepath` with librosa, resampled to `sr` Hz (22050 by default)."""
        arr, _ = librosa.load(filepath, sr=sr)
        return arr

    def __call__(self):
        """Return a dict mapping each example sentence to its loaded waveform."""
        examples = {}
        for idx, (sentence, url) in enumerate(zip(self.sentences, self.urls), start=1):
            filename = f"{idx}.wav"
            filepath = self.download_audio(url, filename)
            examples[sentence] = self.get_audio(filepath)
        return examples