Spaces:
Runtime error
Runtime error
File size: 4,983 Bytes
38a17ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import os
import glob
import torch
import torchaudio
from tqdm import tqdm
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.tts import ChatterboxTTS, punc_norm
from src.chatterbox_.models.s3tokenizer import S3_SR
from src.utils import setup_logger
from src.config import TrainConfig
# Module-level logger shared by preprocess_dataset_file_based and the __main__ entry point.
logger = setup_logger(__name__)
def _load_mono_wav(wav_path, target_sr, resamplers):
    """Load ``wav_path`` as a single-channel (1, T) tensor resampled to ``target_sr``.

    Multi-channel audio is averaged down to one channel. ``resamplers`` is a
    dict cache keyed by source sample rate, so a ``Resample`` transform is
    built once per distinct input rate instead of once per file.
    """
    wav, sr = torchaudio.load(wav_path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        resampler = resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, target_sr)
            resamplers[sr] = resampler
        wav = resampler(wav)
    return wav


def _fit_prompt(wav, num_samples):
    """Return the first ``num_samples`` samples of a (C, T) ``wav``.

    Clips shorter than ``num_samples`` are right-padded with zeros.
    """
    length = wav.shape[1]
    if length < num_samples:
        return torch.nn.functional.pad(wav, (0, num_samples - length))
    return wav[:, :num_samples]


def _encode_text(tts_engine, clean_text, is_turbo):
    """Tokenize already-normalized text with the engine's tokenizer.

    Turbo engines: the tokenizer is called with ``return_tensors="pt"`` and
    read via ``.input_ids``; its EOS id is appended when one is defined.
    Other engines: ``tokenizer.text_to_tokens`` is used.
    Returns a 1-D CPU tensor of token ids.
    """
    tokenizer = tts_engine.tokenizer
    if is_turbo:
        ids = tokenizer(clean_text, return_tensors="pt").input_ids[0].cpu()
        eos_id = tokenizer.eos_token_id
        if eos_id is None:
            return ids
        return torch.cat([ids, torch.tensor([eos_id], dtype=ids.dtype)], dim=0)
    return tokenizer.text_to_tokens(clean_text).squeeze(0).cpu()


def preprocess_dataset_file_based(config, tts_engine: ChatterboxTTS):
    """
    Reads .wav and .txt file pairs in a folder, processes them, and saves them as .pt.

    Structure of ``config.wav_dir``:
        ID.wav (Audio)
        ID.txt (Text)

    For every pair, ``config.preprocessed_dir/ID.pt`` is written with the keys
    ``speech_tokens`` (S3 tokens of the clip plus a stop token),
    ``speaker_emb`` (voice-encoder embedding), ``prompt_tokens`` (S3 tokens of
    the first ``config.prompt_duration`` seconds, zero-padded when shorter)
    and ``text_tokens``. A failure on one file is logged and that file is
    skipped. Nothing is returned.

    Config attributes read: preprocessed_dir, wav_dir, prompt_duration, is_turbo.
    """
    os.makedirs(config.preprocessed_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tts_engine.ve.to(device)
    tts_engine.s3gen.to(device)
    tts_engine.ve.eval()
    tts_engine.s3gen.eval()

    # Escape the directory so glob metacharacters in the path ("[", "*", "?")
    # match literally; sort for a deterministic processing order.
    search_path = os.path.join(glob.escape(config.wav_dir), "*.wav")
    wav_files = sorted(glob.glob(search_path))
    if not wav_files:
        logger.error(f"ERROR: No .wav files found in folder '{config.wav_dir}'!")
        return

    logger.info(f"Processing dataset... Found audio file: {len(wav_files)}")

    success_count = 0
    # 6562 is the fallback used when t3.hp has no stop_speech_token attribute.
    speech_stop_id = getattr(tts_engine.t3.hp, "stop_speech_token", 6562)
    # Loop-invariant: the prompt length depends only on config.
    prompt_samples = int(config.prompt_duration * S3_SR)
    resamplers = {}

    for wav_path in tqdm(wav_files, desc="Preprocessing"):
        # Bound before the try so the except handler can always name the file.
        filename = os.path.basename(wav_path)
        file_id = os.path.splitext(filename)[0]
        try:
            txt_path = os.path.join(config.wav_dir, f"{file_id}.txt")
            if not os.path.exists(txt_path):
                logger.warning(f"Text file not found, skipping: {file_id}")
                continue

            with open(txt_path, "r", encoding="utf-8") as f:
                raw_text = f.read().strip()
            if not raw_text:
                logger.warning(f"Empty text file, skipping: {file_id}")
                continue

            wav = _load_mono_wav(wav_path, S3_SR, resamplers)
            # The voice encoder takes numpy input. Take it from the CPU tensor
            # before moving it, to avoid a device round-trip.
            wav_np = wav[0].numpy()
            wav = wav.to(device)

            with torch.no_grad():
                spk_emb_np = tts_engine.ve.embeds_from_wavs([wav_np], sample_rate=S3_SR)
                speaker_emb = torch.from_numpy(spk_emb_np[0]).cpu()

                s_tokens, _ = tts_engine.s3gen.tokenizer(wav.unsqueeze(0))
                # reshape(-1) rather than squeeze(): a one-token result must
                # stay 1-D, or torch.cat below raises. Assumes batch size 1.
                raw_speech_tokens = s_tokens.reshape(-1).cpu()
                stop_speech_tensor = torch.tensor(
                    [speech_stop_id], dtype=raw_speech_tokens.dtype
                )
                speech_tokens = torch.cat([raw_speech_tokens, stop_speech_tensor], dim=0)

                prompt_wav = _fit_prompt(wav, prompt_samples)
                p_tokens, _ = tts_engine.s3gen.tokenizer(prompt_wav.unsqueeze(0))
                prompt_tokens = p_tokens.reshape(-1).cpu()

            text_tokens = _encode_text(tts_engine, punc_norm(raw_text), config.is_turbo)

            # The output keeps the source ID: ID.pt
            save_path = os.path.join(config.preprocessed_dir, f"{file_id}.pt")
            torch.save(
                {
                    "speech_tokens": speech_tokens,
                    "speaker_emb": speaker_emb,
                    "prompt_tokens": prompt_tokens,
                    "text_tokens": text_tokens,
                },
                save_path,
            )
            success_count += 1
        except Exception as e:
            # Best-effort: one bad file must not abort the whole dataset run.
            logger.error(f"Error ({filename}): {e}")

    logger.info(f"Preprocessing completed! Success: {success_count}/{len(wav_files)}")
if __name__ == "__main__":
    cfg = TrainConfig()
    # Both engine classes are constructed via from_local(); choose by config flag.
    EngineClass = ChatterboxTurboTTS if cfg.is_turbo else ChatterboxTTS
    # Log the bare class name instead of the "<class '...'>" repr.
    logger.info(f"{EngineClass.__name__} engine starting...")
    # Load on CPU; preprocess_dataset_file_based moves the voice encoder and
    # s3gen submodules to the available device itself.
    tts_engine = EngineClass.from_local(cfg.model_dir, device="cpu")
    preprocess_dataset_file_based(cfg, tts_engine)