File size: 4,169 Bytes
308155b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import os
import torch
import torchaudio
import pandas as pd
from tqdm import tqdm
from src.chatterbox_.tts import ChatterboxTTS, punc_norm
from src.chatterbox_.models.s3tokenizer import S3_SR
from src.utils import setup_logger
logger = setup_logger(__name__)
def preprocess_dataset_ljspeech(config, tts_engine: ChatterboxTTS):
    """Preprocess an LJSpeech-style dataset into per-utterance ``.pt`` files.

    Reads a pipe-separated metadata CSV (``config.csv_path``; columns are
    filename | raw text | [normalized text]) and, for each referenced wav:

      * extracts a speaker embedding with the voice encoder (``ve``),
      * tokenizes the full utterance with the S3 speech tokenizer and
        appends the stop-speech token,
      * tokenizes a fixed-duration prompt clip taken from the start of
        the wav (zero-padded if the file is shorter),
      * tokenizes the punctuation-normalized transcript text,

    then saves the four tensors to ``config.preprocessed_dir/<name>.pt``.

    Args:
        config: Namespace-like object providing ``csv_path``, ``wav_dir``,
            ``preprocessed_dir``, ``prompt_duration`` (seconds) and
            ``is_turbo`` attributes.
        tts_engine: Loaded :class:`ChatterboxTTS` exposing ``ve``,
            ``s3gen``, ``t3`` and ``tokenizer``.

    Rows whose wav file is missing are skipped silently; any other per-row
    failure is logged (with the row identity and traceback) and skipped so
    a single bad file cannot abort the whole run.
    """
    # quoting=3 (csv.QUOTE_NONE): LJSpeech transcripts contain raw quote
    # characters that must not be treated as field delimiters.
    data = pd.read_csv(config.csv_path, sep="|", header=None, quoting=3)
    os.makedirs(config.preprocessed_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tts_engine.ve.to(device)
    tts_engine.s3gen.to(device)
    logger.info(f"Processing dataset... Total: {len(data)}")
    success_count = 0
    # Fall back to 6562 when the hparams do not define a stop token id.
    SPEECH_STOP_ID = getattr(tts_engine.t3.hp, 'stop_speech_token', 6562)
    # Cache one Resample transform per source sample rate instead of
    # rebuilding it for every file (construction is loop-invariant per rate).
    resamplers = {}
    for idx, row in tqdm(data.iterrows(), total=len(data)):
        filename = str(row[0])
        try:
            if not filename.endswith(".wav"):
                filename += ".wav"
            wav_path = os.path.join(config.wav_dir, filename)
            if not os.path.exists(wav_path):
                continue
            wav, sr = torchaudio.load(wav_path)
            # Downmix multi-channel audio to mono.
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            if sr != S3_SR:
                if sr not in resamplers:
                    resamplers[sr] = torchaudio.transforms.Resample(sr, S3_SR)
                wav = resamplers[sr](wav)
            wav = wav.to(device)
            with torch.no_grad():
                # Voice encoder consumes CPU numpy audio at S3_SR.
                wav_np = wav.cpu().squeeze().numpy()
                spk_emb_np = tts_engine.ve.embeds_from_wavs([wav_np], sample_rate=S3_SR)
                speaker_emb = torch.from_numpy(spk_emb_np[0]).cpu()
                # Full-utterance speech tokens, terminated with the stop id
                # so the model can learn where speech ends.
                s_tokens, _ = tts_engine.s3gen.tokenizer(wav.unsqueeze(0))
                raw_speech_tokens = s_tokens.squeeze().cpu()
                stop_speech_tensor = torch.tensor([SPEECH_STOP_ID], dtype=raw_speech_tokens.dtype)
                speech_tokens = torch.cat([raw_speech_tokens, stop_speech_tensor], dim=0)
                # Fixed-length prompt clip from the start of the utterance;
                # zero-pad short files so every prompt has the same duration.
                prompt_samples = int(config.prompt_duration * S3_SR)
                if wav.shape[1] < prompt_samples:
                    prompt_wav = torch.nn.functional.pad(wav, (0, prompt_samples - wav.shape[1]))
                else:
                    prompt_wav = wav[:, :prompt_samples]
                p_tokens, _ = tts_engine.s3gen.tokenizer(prompt_wav.unsqueeze(0))
                prompt_tokens = p_tokens.squeeze().cpu()
            # Prefer the normalized transcript (column 2) when present.
            raw_text = str(row[2]) if len(row) > 2 else str(row[1])
            clean_text = punc_norm(raw_text)
            # Text tokenization differs between the turbo engine (HF-style
            # tokenizer with optional EOS) and the standard project tokenizer.
            if config.is_turbo:
                token_output = tts_engine.tokenizer(clean_text, return_tensors="pt")
                raw_text_tokens = token_output.input_ids[0].cpu()
                if tts_engine.tokenizer.eos_token_id is not None:
                    text_eos = torch.tensor([tts_engine.tokenizer.eos_token_id], dtype=raw_text_tokens.dtype)
                    text_tokens = torch.cat([raw_text_tokens, text_eos], dim=0)
                else:
                    text_tokens = raw_text_tokens
            else:
                text_tokens = tts_engine.tokenizer.text_to_tokens(clean_text).squeeze(0).cpu()
            save_path = os.path.join(config.preprocessed_dir, filename.replace(".wav", ".pt"))
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save({
                "speech_tokens": speech_tokens,
                "speaker_emb": speaker_emb,
                "prompt_tokens": prompt_tokens,
                "text_tokens": text_tokens
            }, save_path)
            success_count += 1
        except Exception as e:
            # Identify the failing row (the original logged a literal
            # "(unknown)" placeholder) and keep the traceback for debugging,
            # without aborting the remaining rows.
            logger.exception(f"Error (row {idx}, {filename}): {e}")
            continue
    logger.info(f"Preprocessing completed! Success: {success_count}/{len(data)}")