import torch
import librosa
import requests
import time
from nemo.collections.tts.models import AudioCodecModel
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import os


@dataclass
class Config:
    model_name: str = "nineninesix/kani-tts-450m-0.1-pt"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    device_map: str = "auto"
    tokeniser_length: int = 64400
    start_of_text: int = 1
    end_of_text: int = 2
    max_new_tokens: int = 1200
    temperature: float = 1.4
    top_p: float = 0.95
    repetition_penalty: float = 1.1

class NemoAudioPlayer:
    """Decodes generated speech tokens back into a waveform with the NeMo nano codec."""

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")

        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        # Optional text tokenizer, used only to recover the prompt text
        # from the generated sequence.
        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Special token IDs: the control tokens sit directly after the text
        # vocabulary, and the audio codes start at tokeniser_length + 10.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    def output_validation(self, out_ids):
        """Validate that the output contains the required speech tokens."""
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids

        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids):
        """Extract nano codec tokens from the model output."""
        try:
            start_a_idx = (out_ids == self.start_of_speech).nonzero(as_tuple=True)[0].item()
            end_a_idx = (out_ids == self.end_of_speech).nonzero(as_tuple=True)[0].item()
        except (IndexError, RuntimeError):
            raise ValueError('Speech start/end tokens not found!')

        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]

        # Each codec frame is emitted as four consecutive tokens, one per codebook.
        if len(audio_codes) % 4:
            raise ValueError('Audio codes length must be a multiple of 4!')
        audio_codes = audio_codes.reshape(-1, 4)

        # Undo the per-codebook offset and the global audio-token offset.
        audio_codes = audio_codes - torch.tensor([self.codebook_size * i for i in range(4)])
        audio_codes = audio_codes - self.audio_tokens_start

        if (audio_codes < 0).sum().item() > 0:
            raise ValueError('Invalid audio tokens detected!')

        # NeMo expects codes shaped (batch, codebooks, frames).
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids):
        """Extract the prompt text from the model output."""
        try:
            start_t_idx = (out_ids == self.start_of_text).nonzero(as_tuple=True)[0].item()
            end_t_idx = (out_ids == self.end_of_text).nonzero(as_tuple=True)[0].item()
        except (IndexError, RuntimeError):
            raise ValueError('Text start/end tokens not found!')

        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids):
        """Convert model output to an audio waveform."""
        out_ids = out_ids.flatten()

        self.output_validation(out_ids)

        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)

        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None
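
# Note on the expected sequence layout (inferred from KaniModel below, not spelled
# out in the original listing): the prompt is framed as
#   [start_of_human] <text ids> [end_of_text] [end_of_human]
# and the model's continuation is expected to contain a
#   [start_of_speech] <audio codes, 4 per codec frame> [end_of_speech]
# span; generation stops at end_of_speech because it is passed as eos_token_id.
# get_waveform() only requires the text span and the speech span to be present.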

class KaniModel:
    def __init__(self, config, player: NemoAudioPlayer, token: str) -> None:
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")

        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )

        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model."""
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids

        # Wrap the text in the chat-style control tokens:
        # [start_of_human] <text> [end_of_text] [end_of_human]
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        return modified_input_ids, attention_mask

    def model_request(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            t: float,
            top_p: float,
            rp: float,
            max_tok: int) -> torch.Tensor:
        """Generate tokens using the model."""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tok,
                do_sample=True,
                temperature=t,
                top_p=top_p,
                repetition_penalty=rp,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
            )
        return generated_ids.to('cpu')

    def time_report(self, point_1, point_2, point_3):
        """Format timings (in seconds) for token generation and codec decoding."""
        model_request = point_2 - point_1
        player_time = point_3 - point_2
        total_time = point_3 - point_1
        report = f"SPEECH TOKENS: {model_request:.2f}s\nCODEC: {player_time:.2f}s\nTOTAL: {total_time:.2f}s"
        return report

    def run_model(self, text: str, t: float, top_p: float, rp: float, max_tok: int):
        """Complete pipeline: text -> tokens -> generation -> audio."""
        input_ids, attention_mask = self.get_input_ids(text)

        point_1 = time.time()
        model_output = self.model_request(input_ids, attention_mask, t, top_p, rp, max_tok)

        point_2 = time.time()
        audio, _ = self.player.get_waveform(model_output)

        point_3 = time.time()
        return audio, text, self.time_report(point_1, point_2, point_3)
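
# Minimal usage sketch (not part of the classes above): wires Config, NemoAudioPlayer
# and KaniModel together and writes the result to a WAV file. The prompt text, the
# HF_TOKEN environment variable, the output file name, the soundfile dependency and
# the 22050 Hz rate (implied by the 22khz codec checkpoint name) are illustrative
# assumptions, not part of the original pipeline.
if __name__ == "__main__":
    import soundfile as sf  # extra dependency, used only to save the output

    config = Config()
    player = NemoAudioPlayer(config, text_tokenizer_name=config.model_name)
    model = KaniModel(config, player, token=os.getenv("HF_TOKEN"))

    audio, text, report = model.run_model(
        "Hello! This is a quick test of Kani TTS.",
        t=config.temperature,
        top_p=config.top_p,
        rp=config.repetition_penalty,
        max_tok=config.max_new_tokens,
    )
    print(report)
    sf.write("output.wav", audio, 22050)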