File size: 12,065 Bytes
67dfc0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53f67e7
67dfc0f
 
53f67e7
 
 
 
67dfc0f
 
53f67e7
 
67dfc0f
53f67e7
 
67dfc0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53f67e7
 
67dfc0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53f67e7
67dfc0f
 
 
53f67e7
 
 
 
 
67dfc0f
 
53f67e7
 
 
 
67dfc0f
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
#!/usr/bin/env python3
"""
Nigerian TTS Data Preprocessor
==============================
Runs on HuggingFace FREE CPU to preprocess audio data.
Downloads datasets, encodes with WavTokenizer, saves to HF Hub.
"""

import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
import gradio as gr
import gc
from datasets import load_dataset, concatenate_datasets, Dataset
from huggingface_hub import login, hf_hub_download, HfApi
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIG
# ============================================================================

# Default Hub token taken from the environment. run_full_preprocessing()
# overwrites this global with the token supplied through the UI.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Hub repo whose text tokenizer is loaded by load_models().
BASE_MODEL = "HuggingFaceTB/SmolLM2-360M"
# Hub repo plus the config/checkpoint file names that load_models() downloads
# with hf_hub_download() to build the WavTokenizer.
WAVTOKENIZER_REPO = "novateur/WavTokenizer-medium-speech-75token"
WAVTOKENIZER_CONFIG = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
WAVTOKENIZER_CHECKPOINT = "wavtokenizer_medium_speech_320_24k_v2.ckpt"

# Target sampling rate in Hz; encode_audio() resamples any other rate to this.
SAMPLE_RATE = 24000
# NOTE(review): not referenced in the visible part of this file; presumably the
# WavTokenizer codebook size (the config file name contains "code4096") - confirm.
AUDIO_VOCAB_SIZE = 4096
# Clips longer than this are dropped by preprocess_dataset().
MAX_AUDIO_LENGTH = 20  # seconds

# Dataset that new samples are appended to, and the repo the combined result
# is pushed to (see run_full_preprocessing()).
EXISTING_DATASET = "UbuntuFarms/nigerian-tts-preprocessed-v2"
OUTPUT_DATASET = "UbuntuFarms/nigerian-tts-preprocessed-v3"

DEVICE = "cpu"  # Free tier = CPU only

# Data sources to add - VERIFIED WORKING datasets
# Maps a language name (also used as the "[language]" prompt tag) to a list of
# source descriptors. Each descriptor has:
#   "name":   label shown in progress/status messages
#   "id":     Hub dataset id passed to load_dataset()
#   "subset": optional dataset config name, or None
DATA_SOURCES = {
    "pidgin": [
        # WORKING: 5883 samples with audio and text
        {"name": "Pidgin ASR Combined", "id": "timniel/Pidgin_ASR_Dataset_Combined", "subset": None},
        # WORKING: 65 samples with audio and text
        {"name": "Nigerian Pidgin Speech", "id": "Rexe/nigerian-pidgin-speech", "subset": None},
    ],
    "english": [
        # WORKING: 498 samples Nigerian English
        {"name": "Nigerian English TTS", "id": "Donmonc/nigerian_english_tts", "subset": None},
    ],
    # Note: yoruba, hausa, igbo already have 50k+ samples in existing dataset
    # Only add more if needed
}

# Global models
# Both stay None until load_models() runs; preprocess_dataset() refuses to
# work while WAVTOKENIZER is None.
WAVTOKENIZER = None
TOKENIZER = None

# ============================================================================
# MODEL LOADING
# ============================================================================

def load_models():
    """Populate the module-level TOKENIZER and WAVTOKENIZER globals.

    The text tokenizer is loaded from BASE_MODEL and extended with the audio
    marker and the per-language tags. The WavTokenizer config and checkpoint
    are downloaded from WAVTOKENIZER_REPO, and the resulting model is moved to
    DEVICE and switched to eval mode.

    Returns:
        A human-readable status string (shown in the UI status box).
    """
    global WAVTOKENIZER, TOKENIZER

    print("Loading models on CPU...")

    # --- Text tokenizer -----------------------------------------------------
    from transformers import AutoTokenizer
    TOKENIZER = AutoTokenizer.from_pretrained(BASE_MODEL)
    if TOKENIZER.pad_token is None:
        # No dedicated padding token: reuse end-of-sequence.
        TOKENIZER.pad_token = TOKENIZER.eos_token
    extra_tokens = {
        "additional_special_tokens": [
            "<|audio|>",
            "[hausa]",
            "[yoruba]",
            "[igbo]",
            "[pidgin]",
            "[english]",
        ]
    }
    TOKENIZER.add_special_tokens(extra_tokens)
    print(f"Tokenizer loaded: {len(TOKENIZER)} tokens")

    # --- WavTokenizer -------------------------------------------------------
    # Fetch the config first, then the checkpoint.
    config_path, checkpoint_path = (
        hf_hub_download(WAVTOKENIZER_REPO, filename)
        for filename in (WAVTOKENIZER_CONFIG, WAVTOKENIZER_CHECKPOINT)
    )

    from outetts.wav_tokenizer.decoder import WavTokenizer
    WAVTOKENIZER = WavTokenizer.from_pretrained0802(config_path, checkpoint_path).to(DEVICE)
    WAVTOKENIZER.eval()
    print("WavTokenizer loaded on CPU")

    return "Models loaded successfully!"

def encode_audio(audio_array, sample_rate):
    """Encode a waveform into a list of WavTokenizer code indices.

    Args:
        audio_array: Waveform as a NumPy array, a torch tensor, or any
            sequence ``torch.as_tensor`` accepts. A 2-D input is averaged
            over its first dimension to obtain mono audio.
        sample_rate: Sampling rate of ``audio_array`` in Hz. Audio at any
            other rate than SAMPLE_RATE is resampled first.

    Returns:
        The codes produced by ``WAVTOKENIZER.encode_infer`` with singleton
        dimensions removed, as a Python list. The result is always a list,
        even when the encoder yields a single code.

    Raises:
        ValueError: If the audio contains no samples.
    """
    import torchaudio.functional as F

    # as_tensor handles ndarrays, tensors and plain sequences alike.
    audio_tensor = torch.as_tensor(audio_array, dtype=torch.float32)

    # Ensure mono.
    # NOTE(review): averaging dim 0 assumes a (channels, samples) layout - confirm
    # against the datasets being loaded.
    if audio_tensor.dim() == 2:
        audio_tensor = audio_tensor.mean(dim=0)

    if audio_tensor.numel() == 0:
        # torch.max() on an empty tensor fails with an opaque message.
        raise ValueError("Cannot encode empty audio")

    # Resample to the rate the codec expects.
    if sample_rate != SAMPLE_RATE:
        audio_tensor = F.resample(audio_tensor, sample_rate, SAMPLE_RATE)

    # Peak-normalize; the epsilon avoids division by zero on silent clips.
    peak = torch.max(torch.abs(audio_tensor))
    audio_tensor = audio_tensor / (peak + 1e-8)

    # Add two leading singleton dimensions before handing over to the encoder.
    audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        _, codes = WAVTOKENIZER.encode_infer(audio_tensor, bandwidth_id=torch.tensor([0], device=DEVICE))

    # squeeze() turns a single-code result into a 0-d tensor, whose tolist()
    # is a bare int; atleast_1d keeps the return type a list in that case.
    return torch.atleast_1d(codes.squeeze()).cpu().tolist()

# ============================================================================
# PREPROCESSING
# ============================================================================

def _build_sample(item, audio_col, text_col, language, text_vocab_size):
    """Turn one dataset row into a training sample, or return None to skip it.

    A row is skipped when its audio is not a dict, its text is not a string of
    at least 2 non-blank characters, its duration is outside
    [0.5 s, MAX_AUDIO_LENGTH], or it encodes to fewer than 10 audio codes.
    Exceptions from encoding/tokenizing propagate to the caller.
    """
    audio_data = item[audio_col]
    if not isinstance(audio_data, dict):
        return None
    audio_array = audio_data["array"]
    sr = audio_data["sampling_rate"]

    text = item[text_col]
    if not isinstance(text, str) or len(text.strip()) < 2:
        return None

    # Check duration
    duration = len(audio_array) / sr
    if duration < 0.5 or duration > MAX_AUDIO_LENGTH:
        return None

    # Encode audio
    audio_codes = encode_audio(audio_array, sr)
    if len(audio_codes) < 10:
        return None

    # Audio codes are shifted past the text vocabulary so both kinds of id
    # can live in one sequence.
    prompt = f"[{language}] {text.strip()} <|audio|>"
    text_ids = TOKENIZER.encode(prompt, add_special_tokens=False)
    audio_ids = [code + text_vocab_size for code in audio_codes]
    return {"input_ids": text_ids + audio_ids, "language": language}


def preprocess_dataset(source_info, language, max_samples, progress=gr.Progress()):
    """Download one dataset and convert its rows into training samples.

    Args:
        source_info: Descriptor dict with "name", "id" and optional "subset".
        language: Language name; used as the "[language]" prompt tag and
            stored on every sample.
        max_samples: Upper bound on rows taken from the dataset. When the
            dataset is larger, a seeded random subset is used.
        progress: Gradio progress callback.

    Returns:
        ``(samples, message)``. ``samples`` is a list of
        ``{"input_ids", "language"}`` dicts, or None when the models are not
        loaded, the columns cannot be found, or loading fails. ``message`` is
        a status string; it reports skipped/failed row counts when non-zero.
    """
    if WAVTOKENIZER is None:
        return None, "Please load models first!"

    text_vocab_size = len(TOKENIZER)
    processed = []
    skipped = 0
    failed = 0

    try:
        progress(0, desc=f"Loading {source_info['name']}...")

        # Load dataset (the config name is only passed when one is given).
        # NOTE(review): trust_remote_code=True lets the dataset repo run its
        # own loading script - only use with sources you trust.
        load_args = [source_info["id"]]
        if source_info.get("subset"):
            load_args.append(source_info["subset"])
        ds = load_dataset(*load_args, split="train", trust_remote_code=True)

        # The UI slider may hand over a float; range() needs an int.
        total = min(len(ds), int(max_samples))
        if total < len(ds):
            ds = ds.shuffle(seed=42).select(range(total))

        # Find columns
        audio_col = next((c for c in ds.column_names if "audio" in c.lower()), None)
        text_col = next((c for c in ds.column_names if c in ["text", "sentence", "transcription", "transcript"]), None)

        if not audio_col or not text_col:
            return None, f"Could not find audio/text columns in {ds.column_names}"

        # Process
        for i, item in enumerate(ds):
            if i % 10 == 0:
                progress(i / total, desc=f"Processing {i}/{total}...")

            try:
                sample = _build_sample(item, audio_col, text_col, language, text_vocab_size)
            except Exception as exc:
                # Best effort: one bad row must not abort the run, but do not
                # hide it either - count it and show the first few causes.
                failed += 1
                if failed <= 5:
                    print(f"Sample {i} from {source_info['name']} failed: {exc}")
            else:
                if sample is None:
                    skipped += 1
                else:
                    processed.append(sample)

            # Memory cleanup every 100 samples, whether or not the row was kept
            if i % 100 == 0:
                gc.collect()

        progress(1.0, desc="Done!")
        message = f"Processed {len(processed)} samples for {language}"
        if skipped or failed:
            message += f" ({skipped} skipped, {failed} failed)"
        return processed, message

    except Exception as e:
        return None, f"Error: {str(e)}"

def run_full_preprocessing(languages, max_per_source, hf_token, progress=gr.Progress()):
    """Preprocess the selected languages, merge with the existing data, push to the Hub.

    Args:
        languages: Comma-separated language names (keys of DATA_SOURCES).
            Blank entries are ignored and repeated names are processed once.
        max_per_source: Maximum number of rows taken from each source dataset.
        hf_token: HuggingFace token used for login and for the push.
        progress: Gradio progress callback.

    Returns:
        A multi-line status log. Errors while combining or pushing are
        reported in the log rather than raised.
    """
    global HF_TOKEN
    from collections import Counter

    if not hf_token:
        return "Please provide HuggingFace token!"

    HF_TOKEN = hf_token
    login(token=HF_TOKEN)

    # Load models if needed
    if WAVTOKENIZER is None:
        load_models()

    # The UI slider may hand over a float.
    max_per_source = int(max_per_source)

    all_samples = []
    status_log = []

    # Normalize the selection: drop blanks and de-duplicate while keeping the
    # order, so "pidgin, pidgin" does not add every sample twice.
    cleaned = (part.strip().lower() for part in languages.split(","))
    selected_langs = list(dict.fromkeys(name for name in cleaned if name))

    for lang in selected_langs:
        if lang not in DATA_SOURCES:
            status_log.append(f"Unknown language: {lang}")
            continue

        for source in DATA_SOURCES[lang]:
            progress(0, desc=f"Processing {source['name']}...")
            samples, msg = preprocess_dataset(source, lang, max_per_source, progress)
            status_log.append(msg)

            if samples:
                all_samples.extend(samples)

            gc.collect()

    if not all_samples:
        return "\n".join(status_log) + "\n\nNo samples processed!"

    status_log.append(f"\nTotal new samples: {len(all_samples)}")

    # Load existing and combine
    progress(0.9, desc="Combining with existing dataset...")
    try:
        existing = load_dataset(EXISTING_DATASET, split="train")
        existing_slim = existing.select_columns(["input_ids", "language"])

        new_ds = Dataset.from_list(all_samples)
        combined = concatenate_datasets([existing_slim, new_ds])

        status_log.append(f"Combined: {len(combined)} total samples")

        # Count languages from the "language" column alone; iterating whole
        # rows would also materialize every input_ids list just to count.
        lang_counts = Counter(combined["language"])

        status_log.append("\nLanguage distribution:")
        for lang_name, count in lang_counts.most_common():
            status_log.append(f"  {lang_name}: {count:,}")

        # Push to Hub
        progress(0.95, desc="Pushing to HuggingFace Hub...")
        combined.push_to_hub(OUTPUT_DATASET, token=HF_TOKEN)
        status_log.append(f"\nDataset saved: https://huggingface.co/datasets/{OUTPUT_DATASET}")

    except Exception as e:
        status_log.append(f"Error combining/pushing: {str(e)}")

    progress(1.0, desc="Complete!")
    return "\n".join(status_log)

# ============================================================================
# GRADIO UI
# ============================================================================

# Startup banner; printed whenever this module is executed or imported.
print("=" * 60)
print("NIGERIAN TTS DATA PREPROCESSOR")
print("Runs on FREE CPU - saves GPU costs!")
print("=" * 60)

# Two-column layout: inputs and action buttons in the first column, a status
# box in the second, followed by static instructions.
with gr.Blocks(title="Nigerian TTS Preprocessor") as demo:
    gr.Markdown("# Nigerian TTS Data Preprocessor")
    gr.Markdown("Preprocess audio datasets on **FREE CPU** to save RunPod GPU costs.")

    with gr.Row():
        with gr.Column():
            # Token is passed to run_full_preprocessing() for login and push.
            hf_token = gr.Textbox(
                label="HuggingFace Token (write access)",
                type="password",
                placeholder="hf_..."
            )
            # Comma-separated names; each must be a key of DATA_SOURCES.
            languages = gr.Textbox(
                label="Languages to process (comma-separated)",
                value="pidgin",
                placeholder="pidgin,english"
            )
            # Forwarded as max_per_source, i.e. the cap applies per dataset.
            max_samples = gr.Slider(
                minimum=100,
                maximum=20000,
                value=5000,
                step=100,
                label="Max samples per source"
            )

            load_btn = gr.Button("1. Load Models", variant="secondary")
            run_btn = gr.Button("2. Run Preprocessing", variant="primary")

        with gr.Column():
            # Receives the string returned by either button's handler.
            output = gr.Textbox(
                label="Status",
                lines=20,
                max_lines=30
            )

    gr.Markdown("""
    ## Instructions
    1. Enter your HuggingFace token (needs write access)
    2. Click "Load Models" to load WavTokenizer
    3. Set languages to "pidgin" (or "pidgin,english" for both)
    4. Click "Run Preprocessing" - this will take a while on CPU!
    5. Once done, train on RunPod using the new dataset

    ## Available Datasets (VERIFIED WORKING)
    - **Pidgin**: `timniel/Pidgin_ASR_Dataset_Combined` (5,883 samples)
    - **Pidgin**: `Rexe/nigerian-pidgin-speech` (65 samples)
    - **English**: `Donmonc/nigerian_english_tts` (498 Nigerian English samples)

    ## Current Status
    - Existing dataset: `UbuntuFarms/nigerian-tts-preprocessed-v2` (148k samples)
    - yoruba: 53,332 samples
    - igbo: 47,526 samples
    - hausa: 47,288 samples
    - **pidgin: 0 samples** (CRITICAL - this is why Pidgin produces white noise!)
    - Output: `UbuntuFarms/nigerian-tts-preprocessed-v3`
    """)

    # Wire the buttons: both handlers return a status string for the box.
    # run_full_preprocessing() loads the models itself when they are not
    # loaded yet, so pressing "Load Models" first is optional.
    load_btn.click(fn=load_models, outputs=output)
    run_btn.click(
        fn=run_full_preprocessing,
        inputs=[languages, max_samples, hf_token],
        outputs=output
    )

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()