File size: 4,169 Bytes
308155b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
import torchaudio
import pandas as pd
from tqdm import tqdm
from src.chatterbox_.tts import ChatterboxTTS, punc_norm
from src.chatterbox_.models.s3tokenizer import S3_SR
from src.utils import setup_logger



# Module-level logger built by the project's shared helper (src.utils.setup_logger).
logger = setup_logger(__name__)

def preprocess_dataset_ljspeech(config, tts_engine: ChatterboxTTS):
    """Preprocess an LJSpeech-style dataset into per-utterance ``.pt`` files.

    Reads ``config.csv_path`` (pipe-separated, LJSpeech metadata layout,
    no header) and, for each row whose wav exists under ``config.wav_dir``,
    saves a dict to ``config.preprocessed_dir/<id>.pt`` containing:

      - ``speech_tokens``: S3 tokens for the full utterance + stop token
      - ``speaker_emb``:   voice-encoder speaker embedding
      - ``prompt_tokens``: S3 tokens for the first ``config.prompt_duration``
                           seconds (right-zero-padded when the clip is shorter)
      - ``text_tokens``:   text token ids (HF tokenizer + EOS when
                           ``config.is_turbo``, else the engine's tokenizer)

    Rows that fail are logged with their identity and skipped — this is a
    deliberate best-effort batch pass, not fail-fast.

    Args:
        config: preprocessing config; must expose ``csv_path``, ``wav_dir``,
            ``preprocessed_dir``, ``prompt_duration`` and ``is_turbo``.
        tts_engine: loaded ChatterboxTTS engine providing ``ve`` (voice
            encoder), ``s3gen`` (speech tokenizer) and ``tokenizer``.
    """
    data = pd.read_csv(config.csv_path, sep="|", header=None, quoting=3)

    os.makedirs(config.preprocessed_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the sub-modules used below need to be on the device.
    tts_engine.ve.to(device)
    tts_engine.s3gen.to(device)

    logger.info(f"Processing dataset... Total: {len(data)}")

    success_count = 0

    # Fall back to 6562 when the hyperparams don't declare a stop token id.
    SPEECH_STOP_ID = getattr(tts_engine.t3.hp, 'stop_speech_token', 6562)

    # Cache one Resample transform per source rate instead of rebuilding it
    # for every file — datasets are usually a single sample rate anyway.
    resamplers = {}

    for idx, row in tqdm(data.iterrows(), total=len(data)):
        try:
            filename = str(row[0])
            if not filename.endswith(".wav"):
                filename += ".wav"

            wav_path = os.path.join(config.wav_dir, filename)

            if not os.path.exists(wav_path):
                logger.warning(f"Missing wav, skipping: {wav_path}")
                continue

            wav, sr = torchaudio.load(wav_path)

            # Downmix multi-channel audio to mono.
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)

            if sr != S3_SR:
                if sr not in resamplers:
                    resamplers[sr] = torchaudio.transforms.Resample(sr, S3_SR)
                wav = resamplers[sr](wav)

            wav = wav.to(device)

            with torch.no_grad():
                # Speaker embedding — the voice encoder consumes numpy audio.
                wav_np = wav.cpu().squeeze().numpy()
                spk_emb_np = tts_engine.ve.embeds_from_wavs([wav_np], sample_rate=S3_SR)
                speaker_emb = torch.from_numpy(spk_emb_np[0]).cpu()

                # Full-utterance speech tokens, terminated with the stop token
                # so training targets carry an explicit end-of-speech marker.
                s_tokens, _ = tts_engine.s3gen.tokenizer(wav.unsqueeze(0))
                raw_speech_tokens = s_tokens.squeeze().cpu()
                stop_speech_tensor = torch.tensor([SPEECH_STOP_ID], dtype=raw_speech_tokens.dtype)
                speech_tokens = torch.cat([raw_speech_tokens, stop_speech_tensor], dim=0)

                # Fixed-length conditioning prompt: first prompt_duration
                # seconds, zero-padded on the right when the clip is shorter.
                prompt_samples = int(config.prompt_duration * S3_SR)
                if wav.shape[1] < prompt_samples:
                    prompt_wav = torch.nn.functional.pad(wav, (0, prompt_samples - wav.shape[1]))
                else:
                    prompt_wav = wav[:, :prompt_samples]

                p_tokens, _ = tts_engine.s3gen.tokenizer(prompt_wav.unsqueeze(0))
                prompt_tokens = p_tokens.squeeze().cpu()

            # LJSpeech column 2 is the normalized transcript; fall back to
            # column 1 when the CSV only has two columns.
            raw_text = str(row[2]) if len(row) > 2 else str(row[1])
            clean_text = punc_norm(raw_text)

            if config.is_turbo:
                # HF-style tokenizer: append EOS explicitly when available.
                token_output = tts_engine.tokenizer(clean_text, return_tensors="pt")
                raw_text_tokens = token_output.input_ids[0].cpu()
                if tts_engine.tokenizer.eos_token_id is not None:
                    text_eos = torch.tensor([tts_engine.tokenizer.eos_token_id], dtype=raw_text_tokens.dtype)
                    text_tokens = torch.cat([raw_text_tokens, text_eos], dim=0)
                else:
                    text_tokens = raw_text_tokens
            else:
                text_tokens = tts_engine.tokenizer.text_to_tokens(clean_text).squeeze(0).cpu()

            # Strip only the trailing ".wav" — str.replace would also mangle
            # filenames that happen to contain ".wav" mid-string.
            save_path = os.path.join(config.preprocessed_dir, filename[:-len(".wav")] + ".pt")
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            torch.save({
                "speech_tokens": speech_tokens,
                "speaker_emb": speaker_emb,
                "prompt_tokens": prompt_tokens,
                "text_tokens": text_tokens
            }, save_path)

            success_count += 1

        except Exception:
            # Best-effort batch pass: identify the failing row (the original
            # message was a useless "(unknown)" placeholder) and keep going.
            # logger.exception retains the full traceback.
            logger.exception(f"Error processing row {idx} ({row[0]!r})")
            continue

    logger.info(f"Preprocessing completed! Success: {success_count}/{len(data)}")