w2v-bert-2.0-gl / prepare_data.py
carlacastedo's picture
Upload folder using huggingface_hub
ce9eaf9 verified
Raw
History Blame Contribute Delete
9.11 kB
import os, re, json
import numpy as np
from datasets import load_dataset, Audio, concatenate_datasets, DatasetDict
from transformers import (
Wav2Vec2CTCTokenizer,
SeamlessM4TFeatureExtractor,
Wav2Vec2BertProcessor,
)
# -----------------------------
# Configuration
# -----------------------------
# Worker-process count handed to the map / cleaning / feature-extraction steps below.
num_proc = 24
# Directory that path_to_audio joins with each Common Voice example's "path" value.
audio_dir = "/home/devbcp/Proyectos/00-DATASETS/ASR/CommonVoice-v23-GL/cv-corpus-23.0-2025-09-05/gl/clips/"
# Parent directory under which the final dataset is saved.
output_path = "/mnt/datos/wav2vec2_datasets"
# -----------------------------
# Funciones auxiliares
# -----------------------------
def path_to_audio(example):
    """Store the clip's full path under ``example["audio"]``.

    The path is the module-level ``audio_dir`` joined with the example's
    ``"path"`` value. The same mapping is updated in place and returned.
    """
    clip_path = os.path.join(audio_dir, example["path"])
    example["audio"] = clip_path
    return example
# Character class of punctuation symbols deleted from transcripts.
chars_to_remove_regex = r"[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\»\«]"


def normalize_text(batch):
    """Lower-case ``batch["text"]`` and delete the punctuation matched above.

    A missing, ``None`` or otherwise falsy text is treated as the empty
    string. The mapping is updated in place and returned.
    """
    raw = batch.get("text")
    lowered = raw.lower() if raw else ""
    batch["text"] = re.sub(chars_to_remove_regex, "", lowered)
    return batch
def filter_valid_audio(example, min_duration=1.0):
    """Tell whether an example's decoded audio is usable.

    Args:
        example: Mapping with an ``"audio"`` entry; that entry is either
            ``None`` or a mapping holding ``"array"`` (the samples) and
            ``"sampling_rate"``.
        min_duration: Minimum accepted duration in seconds.

    Returns:
        ``True`` when the audio is present, non-empty, has a positive
        sampling rate, contains only finite samples, and lasts at least
        ``min_duration`` seconds; ``False`` otherwise.
    """
    audio = example["audio"]
    if audio is None:
        return False

    # Use .get so a missing or None field rejects the example instead of
    # raising, matching how prepare_dataset reads the same mapping.
    samples = audio.get("array")
    rate = audio.get("sampling_rate")
    if samples is None or rate is None:
        return False
    if len(samples) == 0 or rate <= 0:
        return False

    # One pass that catches both NaN and +/-inf samples.
    if not np.isfinite(samples).all():
        return False

    return bool(len(samples) / rate >= min_duration)
def prepare_dataset(batch):
    """Turn one example into model inputs (features + label ids).

    Uses the module-level ``processor`` on the waveform and the module-level
    ``tokenizer`` on the transcript.

    Args:
        batch: A single example with an ``"audio"`` entry (``None`` or a
            mapping holding ``"array"`` and ``"sampling_rate"``) and an
            optional ``"text"`` entry.

    Returns:
        A dict with the keys ``input_features``, ``input_length``,
        ``labels`` and ``skip``. Unusable examples come back with
        ``skip=True`` and placeholder values; usable ones with
        ``skip=False``.
    """
    import logging  # function-scope import: only needed on the failure path

    # Every return path carries the same keys so that all rows produced by
    # map() expose the same columns; a bare {"skip": True} row next to full
    # rows would hand the writer inconsistent keys.
    skipped = {
        "input_features": None,
        "input_length": 0,
        "labels": None,
        "skip": True,
    }

    audio = batch["audio"]
    if audio is None:
        return skipped
    samples = audio.get("array")
    rate = audio.get("sampling_rate")
    if samples is None or len(samples) == 0 or rate is None or rate <= 0:
        return skipped

    try:
        feats = processor(samples, sampling_rate=rate).input_features[0]
    except Exception:
        # Deliberately best-effort: one bad clip must not abort the whole
        # run, but the failure is reported instead of being swallowed.
        logging.getLogger(__name__).warning(
            "Feature extraction failed; skipping example", exc_info=True
        )
        return skipped

    labels = tokenizer(text_target=(batch.get("text") or "")).input_ids
    return {
        "input_features": feats,
        "input_length": len(feats),
        "labels": labels,
        "skip": False,
    }
def map_prepare(ds, name, num_proc=8):
    """Apply ``prepare_dataset`` to every split of ``ds`` and drop skipped rows.

    Args:
        ds: Mapping of split name to dataset; its entries are replaced in
            place with the processed splits.
        name: Short label used in the progress message and progress-bar
            description.
        num_proc: Number of worker processes for the ``map`` step.

    Returns:
        The same ``ds`` object, with every split replaced by its processed
        and filtered version.
    """
    print(f"Procesando {name}...")
    for split in list(ds.keys()):
        prepared = ds[split].map(
            prepare_dataset,
            # All original columns are dropped; only what prepare_dataset
            # returns is kept.
            remove_columns=ds[split].column_names,
            num_proc=num_proc,
            batched=False,
            desc=f"{name}-{split}",
        )
        # Pass only the "skip" column to the predicate so the filter does not
        # have to build a full Python row (including input_features) per
        # example just to read one boolean.
        ds[split] = prepared.filter(
            lambda skip: not skip,
            input_columns=["skip"],
            num_proc=1,
        )
    return ds
def clean_text_and_audio(ds, num_proc_text, min_duration=1.0):
    """Drop unusable examples and normalize the remaining transcripts.

    Steps, in order:
      1. remove rows whose ``"text"`` is ``None`` or blank;
      2. remove rows rejected by ``filter_valid_audio``;
      3. apply ``normalize_text``;
      4. remove rows whose text became blank after normalization
         (for example transcripts made only of punctuation).

    Args:
        ds: Dataset (or dict of splits) exposing ``filter`` and ``map``.
        num_proc_text: Number of worker processes for the normalization map.
        min_duration: Minimum audio duration in seconds, forwarded to
            ``filter_valid_audio``.

    Returns:
        The filtered and normalized dataset.
    """

    def has_text(text):
        return text is not None and text.strip() != ""

    # Only the "text" column is handed to the predicate, so this cheap check
    # does not need to touch the audio column.
    ds = ds.filter(has_text, input_columns=["text"], num_proc=1)
    ds = ds.filter(
        lambda x: filter_valid_audio(x, min_duration=min_duration), num_proc=1
    )
    ds = ds.map(normalize_text, num_proc=num_proc_text)
    # Normalization deletes punctuation, which can leave an empty transcript
    # that the first check could not see.
    ds = ds.filter(has_text, input_columns=["text"], num_proc=1)
    return ds
# -----------------------------
# Load datasets
# -----------------------------
# Each corpus is loaded from a local directory.
common_voice = load_dataset("/home/devbcp/Proyectos/00-DATASETS/ASR/CommonVoice-v23-GL")
openslr = load_dataset("/home/devbcp/Proyectos/00-DATASETS/ASR/OpenSLR-SpeechT-GL-EN")
fleurs = load_dataset("/home/devbcp/Proyectos/00-DATASETS/ASR/FLEURS-SpeechT-GL-EN")
falai = load_dataset("/home/devbcp/Proyectos/00-DATASETS/ASR/FalAI")
transcrispeech = load_dataset("/home/devbcp/Proyectos/00-DATASETS/ASR/Transcrispeech-GL")
rg_podcast = load_dataset("/home/devbcp/Proyectos/00-DATASETS/ASR/RG-Podcast-GL")
# -----------------------------
# Normalize text column names
# -----------------------------
# Rename each corpus' transcript column to the shared name "text".
common_voice = common_voice.rename_column("sentence", "text")
openslr = openslr.rename_column("text_gl", "text")
fleurs = fleurs.rename_column("text_gl", "text")
falai = falai.rename_column("sentence", "text")
# transcrispeech and rg_podcast already have "text"
# -----------------------------
# Create the audio column and cast to 16 kHz
# -----------------------------
# Common Voice gets its "audio" column from path_to_audio before the cast;
# the other corpora are cast directly.
common_voice = common_voice.map(path_to_audio, num_proc=num_proc)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
openslr = openslr.cast_column("audio", Audio(sampling_rate=16000))
fleurs = fleurs.cast_column("audio", Audio(sampling_rate=16000))
falai = falai.cast_column("audio", Audio(sampling_rate=16000))
transcrispeech = transcrispeech.cast_column("audio", Audio(sampling_rate=16000))
rg_podcast = rg_podcast.cast_column("audio", Audio(sampling_rate=16000))
# -----------------------------
# FalAI: use "validated", sample 20%, and create train/val/test
# -----------------------------
falai_validated = falai["validated"]
# Keep the first 20% of the rows after a seeded shuffle.
n = int(0.2 * len(falai_validated))
falai_sampled = falai_validated.shuffle(seed=42).select(range(n))
# 80% goes to train; the held-out 20% is then halved into validation and test.
falai_split = falai_sampled.train_test_split(test_size=0.2, seed=42)
val_test = falai_split["test"].train_test_split(test_size=0.5, seed=42)
falai_reduced = DatasetDict({
"train": falai_split["train"],
"validation": val_test["train"],
"test": val_test["test"]
})
# Transcrispeech: expose the "dev" split under the name "validation".
transcrispeech = DatasetDict({
"train": transcrispeech["train"],
"validation": transcrispeech["dev"],
"test": transcrispeech["test"]
})
# -----------------------------
# RG-Podcast: rename dev -> validation
# -----------------------------
rg_podcast = DatasetDict({
"train": rg_podcast["train"],
"validation": rg_podcast["dev"],
"test": rg_podcast["test"]
})
# -----------------------------
# Text cleaning + short-audio filter
# -----------------------------
# Every corpus goes through the same cleaning with a 1.0 s minimum duration.
common_voice = clean_text_and_audio(common_voice, num_proc_text=num_proc, min_duration=1.0)
openslr = clean_text_and_audio(openslr, num_proc_text=num_proc, min_duration=1.0)
fleurs = clean_text_and_audio(fleurs, num_proc_text=num_proc, min_duration=1.0)
falai_reduced = clean_text_and_audio(falai_reduced, num_proc_text=num_proc, min_duration=1.0)
transcrispeech = clean_text_and_audio(transcrispeech, num_proc_text=num_proc, min_duration=1.0)
rg_podcast = clean_text_and_audio(rg_podcast, num_proc_text=num_proc, min_duration=1.0)
# -----------------------------
# Construcción del vocabulario (versión limpia y correcta)
# -----------------------------
from collections import Counter
import unicodedata
def clean_char(c):
    """Return the NFKC form of ``c`` if it is an allowed vocabulary character.

    Args:
        c: A string, normally a single character.

    Returns:
        The NFKC-normalized character when it is exactly one character from
        the allowed set (lower-case a-z, the accented vowels, "ñ", "ç" and
        the space); ``None`` otherwise.
    """
    valid_chars = "abcdefghijklmnopqrstuvwxyzáéíóúñç "
    c = unicodedata.normalize("NFKC", c)
    # `in` on a str is a substring test, and NFKC can expand one code point
    # into several (e.g. the "ij" and "st" ligatures). Requiring a single
    # character keeps such expansions from being accepted as one token.
    if len(c) == 1 and c in valid_chars:
        return c
    return None
# Collect every transcript from every split of every corpus.
all_texts = []
for ds in [common_voice, openslr, fleurs, falai_reduced, transcrispeech, rg_podcast]:
    for split in ds.keys():
        all_texts.extend(ds[split]["text"])
# Turn non-breaking spaces into plain spaces and trim the ends.
all_texts = [t.replace("\u00A0", " ").strip() for t in all_texts]

# Count only the characters that clean_char accepts.
counter = Counter()
for t in all_texts:
    for c in t.lower():
        c = clean_char(c)
        if c:
            counter[c] += 1

# Ids follow the sorted character order.
vocab_list = sorted(counter.keys())
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
# The space is stored under "|", the word-delimiter token used by the tokenizer.
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

# ensure_ascii=False writes the accented characters literally, so the file
# encoding is pinned to UTF-8 instead of depending on the platform default.
with open("vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab_dict, f, ensure_ascii=False, indent=2)
print("Vocabulario limpio generado con", len(vocab_dict), "tokens")
# -----------------------------
# Tokenizer & Processor
# -----------------------------
# The tokenizer is loaded from the current directory (presumably picking up
# the vocab.json written earlier in this script), using the same special
# tokens that were added to the vocabulary.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
"./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
)
# Feature extractor taken from the pretrained "facebook/w2v-bert-2.0" checkpoint.
feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
# Bundle both; prepare_dataset reads the module-level `processor` and `tokenizer`.
processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# -----------------------------
# Prepare datasets with features
# -----------------------------
# The second argument is the short label shown in the progress output.
common_voice = map_prepare(common_voice, "CV", num_proc)
openslr = map_prepare(openslr, "SLR", num_proc)
fleurs = map_prepare(fleurs, "FLEURS", num_proc)
falai_reduced = map_prepare(falai_reduced, "FalAI", num_proc)
transcrispeech = map_prepare(transcrispeech, "TC", num_proc)
rg_podcast = map_prepare(rg_podcast, "POD", num_proc)
# -----------------------------
# Final concatenation
# -----------------------------
# The corpora are merged in this fixed order for every split.
_corpora = (common_voice, openslr, fleurs, falai_reduced, transcrispeech, rg_podcast)
train_dataset = concatenate_datasets([corpus["train"] for corpus in _corpora])
valid_dataset = concatenate_datasets([corpus["validation"] for corpus in _corpora])
test_dataset = concatenate_datasets([corpus["test"] for corpus in _corpora])
galician_dataset = DatasetDict(
    {"train": train_dataset, "validation": valid_dataset, "test": test_dataset}
)
# -----------------------------
# Save the final dataset
# -----------------------------
# The total number of examples is embedded in the output directory name.
split_sizes = [len(train_dataset), len(valid_dataset), len(test_dataset)]
total = sum(split_sizes)
output_file = f"{output_path}/galician_dataset_w2vbert_complete_{total}"
galician_dataset.save_to_disk(output_file)

print("\nDataset final guardado correctamente:")
for label, size in zip(("Train:", "Validation:", "Test:"), split_sizes):
    print(label, size)
print("Ruta:", output_file)