Spaces:
Runtime error
Runtime error
File size: 5,140 Bytes
38a17ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import json
import torch
import torchaudio
from tqdm import tqdm
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.tts import ChatterboxTTS, punc_norm
from src.chatterbox_.models.s3tokenizer import S3_SR
from src.utils import setup_logger
from src.config import TrainConfig
# Module-level logger obtained from the project's setup_logger helper
# (src.utils); used below for progress messages, skips and per-item errors.
logger = setup_logger(__name__)
def _load_mono_wav(wav_path: str, resamplers: dict) -> torch.Tensor:
    """Load an audio file as a mono ``(1, T)`` CPU tensor sampled at ``S3_SR``.

    Args:
        wav_path: path of the audio file, read with ``torchaudio.load``.
        resamplers: dict mapping a source sample rate to a cached
            ``torchaudio.transforms.Resample`` instance. It is filled lazily so
            the resampling kernel is built once per rate instead of per file.

    Returns:
        The (possibly down-mixed and resampled) waveform tensor.
    """
    wav, sr = torchaudio.load(wav_path)
    if wav.shape[0] > 1:
        # Down-mix to mono by averaging channels, keeping the channel axis.
        wav = wav.mean(dim=0, keepdim=True)
    if sr != S3_SR:
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(sr, S3_SR)
        wav = resamplers[sr](wav)
    return wav


def _extract_speech_features(tts_engine, wav: torch.Tensor, device, prompt_samples: int,
                             speech_stop_id: int):
    """Compute speech tokens, speaker embedding and prompt tokens for one clip.

    Args:
        tts_engine: engine whose ``ve`` (voice encoder) and ``s3gen.tokenizer``
            are used; the caller has already moved both to ``device``.
        wav: mono ``(1, T)`` CPU waveform sampled at ``S3_SR``.
        device: device the waveform is moved to before tokenization.
        prompt_samples: number of leading samples used for the prompt; shorter
            clips are right-padded with zeros up to this length.
        speech_stop_id: token id appended after the full-clip speech tokens.

    Returns:
        ``(speech_tokens, speaker_emb, prompt_tokens)`` as CPU tensors; both
        token tensors are always 1-D.
    """
    # The voice encoder is fed numpy arrays; take them from the CPU tensor.
    wav_np = wav[0].numpy()
    wav = wav.to(device)
    with torch.no_grad():
        spk_emb_np = tts_engine.ve.embeds_from_wavs([wav_np], sample_rate=S3_SR)
        speaker_emb = torch.from_numpy(spk_emb_np[0]).cpu()

        s_tokens, _ = tts_engine.s3gen.tokenizer(wav.unsqueeze(0))
        # reshape(-1) rather than squeeze(): squeeze() turns a single-token
        # result into a 0-d tensor, which torch.cat below would reject.
        raw_speech_tokens = s_tokens.reshape(-1).cpu()
        stop_speech_tensor = torch.tensor([speech_stop_id], dtype=raw_speech_tokens.dtype)
        speech_tokens = torch.cat([raw_speech_tokens, stop_speech_tensor], dim=0)

        # Prompt = first `prompt_samples` samples, zero-padded if the clip is shorter.
        if wav.shape[1] < prompt_samples:
            prompt_wav = torch.nn.functional.pad(wav, (0, prompt_samples - wav.shape[1]))
        else:
            prompt_wav = wav[:, :prompt_samples]
        p_tokens, _ = tts_engine.s3gen.tokenizer(prompt_wav.unsqueeze(0))
        prompt_tokens = p_tokens.reshape(-1).cpu()
    return speech_tokens, speaker_emb, prompt_tokens


def _tokenize_text(tts_engine, clean_text: str, is_turbo: bool) -> torch.Tensor:
    """Return the text token ids for ``clean_text`` as a CPU tensor.

    Turbo engines use a callable tokenizer (``return_tensors="pt"``) and get
    the tokenizer's ``eos_token_id`` appended when it is defined. Other engines
    use ``tokenizer.text_to_tokens``; its output presumably has a leading batch
    dimension of 1, which ``squeeze(0)`` drops (TODO confirm).
    """
    if not is_turbo:
        return tts_engine.tokenizer.text_to_tokens(clean_text).squeeze(0).cpu()

    token_output = tts_engine.tokenizer(clean_text, return_tensors="pt")
    raw_text_tokens = token_output.input_ids[0].cpu()
    eos_id = tts_engine.tokenizer.eos_token_id
    if eos_id is None:
        return raw_text_tokens
    text_eos = torch.tensor([eos_id], dtype=raw_text_tokens.dtype)
    return torch.cat([raw_text_tokens, text_eos], dim=0)


def preprocess_dataset_json_based(config, tts_engine: ChatterboxTTS):
    """Tokenize a JSON-described dataset into one ``.pt`` file per utterance.

    Reads ``config.metadata_path``, which must hold a JSON list of objects with
    at least ``id`` and ``text``. For each entry it loads
    ``{config.wav_dir}/{id}.wav`` and writes ``{config.preprocessed_dir}/{id}.pt``
    containing a dict with:

    - ``speech_tokens``: speech tokens of the full clip plus a trailing stop token
    - ``speaker_emb``: voice-encoder embedding of the full clip
    - ``prompt_tokens``: speech tokens of the first ``config.prompt_duration``
      seconds (zero-padded when the clip is shorter)
    - ``text_tokens``: token ids of the ``punc_norm``-normalised text

    Entries that are not objects, lack an id or text, have no audio file, or
    fail during processing are logged and skipped. Nothing is returned.

    Args:
        config: object providing ``preprocessed_dir``, ``metadata_path``,
            ``wav_dir``, ``prompt_duration`` and ``is_turbo``.
        tts_engine: loaded engine whose ``ve``, ``s3gen``, ``t3.hp`` and
            ``tokenizer`` attributes are used. ``ve`` and ``s3gen`` are moved
            to CUDA when available and put into eval mode.
    """
    os.makedirs(config.preprocessed_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tts_engine.ve.to(device)
    tts_engine.s3gen.to(device)
    tts_engine.ve.eval()
    tts_engine.s3gen.eval()

    if not os.path.exists(config.metadata_path):
        logger.error(f"ERROR: Metadata file not found: '{config.metadata_path}'!")
        return
    with open(config.metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)
    # A top-level JSON object would otherwise be iterated key-by-key.
    if not isinstance(metadata, list):
        logger.error("ERROR: Metadata file must contain a JSON list of items!")
        return
    if not metadata:
        logger.error("ERROR: No items found in metadata file!")
        return
    logger.info(f"Processing dataset... Found items in JSON: {len(metadata)}")

    # 6562 is only the fallback when the model hyper-parameters lack the field.
    speech_stop_id = getattr(tts_engine.t3.hp, 'stop_speech_token', 6562)
    prompt_samples = int(config.prompt_duration * S3_SR)
    resamplers = {}  # source sample rate -> cached Resample transform
    success_count = 0

    for item in tqdm(metadata, desc="Preprocessing"):
        # Non-object entries used to make the except-handler's item.get()
        # raise a second time and abort the whole run; skip them up front.
        if not isinstance(item, dict):
            logger.warning(
                f"Skipping metadata entry that is not an object: {type(item).__name__}"
            )
            continue
        try:
            file_id = item.get("id")
            raw_text = item.get("text", "")
            if not file_id or not raw_text:
                logger.warning("Skipping item with missing id or text")
                continue
            wav_path = os.path.join(config.wav_dir, f"{file_id}.wav")
            if not os.path.exists(wav_path):
                logger.warning(f"Audio file not found, skipping: {file_id}")
                continue

            wav = _load_mono_wav(wav_path, resamplers)
            speech_tokens, speaker_emb, prompt_tokens = _extract_speech_features(
                tts_engine, wav, device, prompt_samples, speech_stop_id
            )
            text_tokens = _tokenize_text(tts_engine, punc_norm(raw_text), config.is_turbo)

            save_path = os.path.join(config.preprocessed_dir, f"{file_id}.pt")
            torch.save({
                "speech_tokens": speech_tokens,
                "speaker_emb": speaker_emb,
                "prompt_tokens": prompt_tokens,
                "text_tokens": text_tokens,
            }, save_path)
            success_count += 1
        except Exception as e:
            # Best-effort: one bad item must not stop the rest of the dataset.
            logger.error(f"Error ({item.get('id', 'unknown')}): {e}")
            continue
    logger.info(f"Preprocessing completed! Success: {success_count}/{len(metadata)}")
if __name__ == "__main__":
    # Script entry point: build the training config, load the matching engine
    # on CPU, then run preprocessing (which relocates submodules itself).
    train_cfg = TrainConfig()
    engine_cls = ChatterboxTurboTTS if train_cfg.is_turbo else ChatterboxTTS
    logger.info(f"{engine_cls} engine starting...")
    engine = engine_cls.from_local(train_cfg.model_dir, device="cpu")
    preprocess_dataset_json_based(train_cfg, engine)