# ASR / src /data /oom_proof_preprocess.py
# MihirRPatil's picture
# deploy: CDAC ASR backend with pitch/stress fix and LLM feedback
# 88a679b
# Raw
# History Blame Contribute Delete
# 20.4 kB
import os
import sys
import io
import re
import shutil
import argparse
import numpy as np
import soundfile as sf
import librosa
import torch
import nltk
from datasets import load_dataset, Audio, load_from_disk, concatenate_datasets, Dataset, DatasetDict
# Add the project root to sys.path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from transformers import Wav2Vec2Processor
from src.g2p.g2p_utils import G2PManager
from src.utils.audio_utils import AudioPreprocessor
# Module-level singletons, populated on first use by init_worker() and read
# by preprocess_batch().
PREPROCESSOR = None
PROCESSOR = None
G2P_MANAGER = None


def init_worker(processor_dir, dict_path):
    """Lazily build the module-level helpers used by ``preprocess_batch``.

    Each global is created only while it is still ``None``, so calling this
    at the top of every batch is cheap after the first call.

    Args:
        processor_dir: Directory passed to ``Wav2Vec2Processor.from_pretrained``.
        dict_path: Pronunciation dictionary path passed to ``G2PManager``.
    """
    global PREPROCESSOR, PROCESSOR, G2P_MANAGER

    if PREPROCESSOR is None:
        # Limit torch to a single intra-op thread before building the
        # audio preprocessor.
        torch.set_num_threads(1)
        PREPROCESSOR = AudioPreprocessor(sr=16000)

    if PROCESSOR is None:
        PROCESSOR = Wav2Vec2Processor.from_pretrained(processor_dir)

    if G2P_MANAGER is None:
        G2P_MANAGER = G2PManager(dict_path=dict_path)
def _read_audio(audio_data):
    """Decode one audio record into ``(samples, sampling_rate)``.

    The record must be a dict carrying one of ``"bytes"``, ``"array"`` or
    ``"path"`` (checked in that order). Files decoded through ``soundfile``
    come back as ``(frames, channels)`` when they have several channels, so
    those are averaged down to a single channel here.

    Raises:
        ValueError: If the record holds no usable audio content.
    """
    if not isinstance(audio_data, dict):
        raise ValueError("Invalid audio format or missing audio content.")

    if audio_data.get("bytes") is not None:
        audio_array, sr = sf.read(io.BytesIO(audio_data["bytes"]))
    elif audio_data.get("array") is not None:
        # Already-decoded samples: passed through unchanged.
        return np.array(audio_data["array"]), audio_data.get("sampling_rate", 16000)
    elif audio_data.get("path") is not None:
        audio_array, sr = sf.read(audio_data["path"])
    else:
        raise ValueError("Invalid audio format or missing audio content.")

    if audio_array.ndim > 1:
        # soundfile layout is (frames, channels): down-mix to mono.
        audio_array = audio_array.mean(axis=1)
    return audio_array, sr


def _resample_to_16k(audio_array, sr):
    """Return ``audio_array`` resampled from ``sr`` Hz to 16 kHz.

    Uses ``scipy.signal.resample_poly`` and falls back to
    ``librosa.resample`` if the scipy path fails for any reason.
    """
    if sr == 16000:
        return audio_array
    try:
        from scipy.signal import resample_poly
        import math

        gcd = math.gcd(sr, 16000)
        return resample_poly(audio_array, 16000 // gcd, sr // gcd)
    except Exception:
        return librosa.resample(audio_array, orig_sr=sr, target_sr=16000, res_type="kaiser_fast")


def preprocess_batch(batch, processor_dir, dict_path):
    """Turn a batch of raw audio/text rows into model inputs and labels.

    For every row the audio is decoded, resampled to 16 kHz, cleaned by the
    shared ``PREPROCESSOR`` and converted to ``input_values`` by the shared
    ``PROCESSOR``; the transcript is converted to phonemes by ``G2P_MANAGER``
    and then to token ids. Rows that fail at any step are dropped, so the
    returned lists may be shorter than the input batch.

    Args:
        batch: Mapping of column name to list of values. Must contain
            ``"audio"``; the transcript is read from the first of ``"text"``,
            ``"transcription"``, ``"sentence"``, ``"normalized_text"`` that
            is present.
        processor_dir: Forwarded to ``init_worker``.
        dict_path: Forwarded to ``init_worker``.

    Returns:
        ``{"input_values": [...], "labels": [...]}`` with one entry per row
        that was processed successfully.
    """
    init_worker(processor_dir, dict_path)

    audios = batch["audio"]
    text_key = next(
        (key for key in ("text", "transcription", "sentence", "normalized_text") if key in batch),
        None,
    )
    texts = batch[text_key] if text_key is not None else [""] * len(audios)

    input_values_list = []
    labels_list = []
    skipped = 0
    first_error = None

    for i, audio_data in enumerate(audios):
        try:
            text = texts[i] if i < len(texts) else ""

            audio_array, sr = _read_audio(audio_data)
            audio_array = _resample_to_16k(audio_array, sr)

            clean_audio = PREPROCESSOR.preprocess(audio_array)
            if len(clean_audio) == 0:
                raise ValueError("Audio clip is empty after FFT filtering and VAD silence trimming.")
            input_values = PROCESSOR(clean_audio, sampling_rate=16000).input_values[0]

            phonemes = G2P_MANAGER.convert_sentence(text)
            if len(phonemes) == 0:
                raise ValueError("Phoneme sequence is empty after G2P conversion.")
            labels = PROCESSOR.tokenizer.convert_tokens_to_ids(phonemes)

            input_values_list.append(input_values)
            labels_list.append(labels)
        except Exception as exc:
            # Best-effort: a bad row must not abort the whole map, but it
            # should not vanish without a trace either.
            skipped += 1
            if first_error is None:
                first_error = exc

    if skipped:
        print(
            f"[preprocess_batch] Skipped {skipped}/{len(audios)} samples; "
            f"first error: {type(first_error).__name__}: {first_error}",
            file=sys.stderr,
            flush=True,
        )

    return {"input_values": input_values_list, "labels": labels_list}
def is_valid_english_script(text):
    """Return True when ``text`` is pure ASCII and has at least one letter.

    Empty or ``None`` input, any non-ASCII character, and strings made up
    only of digits/punctuation/whitespace all yield ``False``.
    """
    if not text or not text.isascii():
        return False
    return re.search(r"[A-Za-z]", text) is not None
# Cached result of tokenizer.get_vocab(), plus the tokenizer it was built
# from so the cache is rebuilt if a different tokenizer is passed in.
_VOCAB_CACHE = None
_VOCAB_CACHE_TOKENIZER = None


def lexical_filter(text, g2p_manager, tokenizer):
    """Return True if at least one word of ``text`` maps into the vocabulary.

    A word counts as valid when ``g2p_manager.convert_word`` yields at least
    one phoneme that is a key of ``tokenizer.get_vocab()``.

    Args:
        text: Transcript to check.
        g2p_manager: Object providing ``tokenize(text)`` and
            ``convert_word(word)``.
        tokenizer: Object providing ``get_vocab()``.

    Returns:
        ``False`` when the text tokenizes to nothing or no word has any
        in-vocabulary phoneme, ``True`` otherwise.
    """
    global _VOCAB_CACHE, _VOCAB_CACHE_TOKENIZER

    words = g2p_manager.tokenize(text)
    if not words:
        return False

    # get_vocab() is looked up once per tokenizer rather than once per call.
    if _VOCAB_CACHE is None or _VOCAB_CACHE_TOKENIZER is not tokenizer:
        _VOCAB_CACHE = tokenizer.get_vocab()
        _VOCAB_CACHE_TOKENIZER = tokenizer
    vocab = _VOCAB_CACHE

    # Stop at the first word with an in-vocabulary phoneme; an empty phoneme
    # list makes the inner any() False, so such words are skipped.
    return any(
        any(p in vocab for p in g2p_manager.convert_word(word))
        for word in words
    )
def build_and_apply_vocab_patch(dataset, processor, g2p_manager, patch_path):
    """Find words whose phonemes hit the unknown token and patch them.

    Words are collected from every sample whose ``source_dataset`` is not
    ``"nptel"`` (samples without that key are treated as NPTEL and skipped).
    For each word whose phoneme ids contain the tokenizer's unknown id, a
    cleaned phoneme list is built: in-vocabulary phonemes are kept, others
    are reduced to their in-vocabulary characters and dropped if nothing
    remains. New patches are merged into the tab-separated file at
    ``patch_path`` and into ``g2p_manager.phoneme_dict``.

    Args:
        dataset: Iterable of dict-like samples.
        processor: Object exposing ``tokenizer`` with ``unk_token_id``,
            ``get_vocab()`` and ``convert_tokens_to_ids()``.
        g2p_manager: Object exposing ``tokenize``, ``convert_word`` and a
            mutable ``phoneme_dict``.
        patch_path: Destination of the ``word<TAB>phonemes`` patch file.
    """
    print("Running G2P vocabulary verification check...")
    tokenizer = processor.tokenizer

    # An id of 0 is legitimate, so only fall back when it is truly missing.
    unk_id = tokenizer.unk_token_id
    if unk_id is None:
        unk_id = 1

    words_to_check = set()
    for sample in dataset:
        if sample.get("source_dataset", "nptel") == "nptel":
            continue
        text = sample.get("text") or sample.get("transcription") or sample.get("sentence") or ""
        words_to_check.update(g2p_manager.tokenize(text))

    print(f"Analyzing {len(words_to_check)} unique words from non-NPTEL datasets...")
    vocab = tokenizer.get_vocab()

    new_patches = {}
    for word in words_to_check:
        phonemes = g2p_manager.convert_word(word)
        if len(phonemes) == 0:
            continue
        ids = tokenizer.convert_tokens_to_ids(phonemes)
        if not any(i == unk_id for i in ids):
            continue

        cleaned_phonemes = []
        for p in phonemes:
            if p in vocab:
                cleaned_phonemes.append(p)
                continue
            # Keep only the characters that are themselves vocabulary entries.
            closest = "".join(char for char in p if char in vocab)
            if closest:
                cleaned_phonemes.append(closest)
        if cleaned_phonemes:
            new_patches[word] = cleaned_phonemes

    if not new_patches:
        print("✓ No vocabulary patches needed. All words mapped successfully.")
        return

    print(f"Writing {len(new_patches)} new vocabulary patches to {patch_path}...")
    existing_patches = {}
    if os.path.exists(patch_path):
        with open(patch_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) >= 2:
                    existing_patches[parts[0]] = parts[1].split()
    existing_patches.update(new_patches)

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    patch_dir = os.path.dirname(patch_path)
    if patch_dir:
        os.makedirs(patch_dir, exist_ok=True)
    with open(patch_path, "w", encoding="utf-8") as f:
        for w, phs in sorted(existing_patches.items()):
            f.write(f"{w}\t{' '.join(phs)}\n")

    g2p_manager.phoneme_dict.update(new_patches)
    print("✅ Vocabulary patch successfully updated and merged!")
def preprocess_and_save_dataset(ds, text_keys, source_label, save_path, processor_dir, dict_path, num_proc, batch_size, g2p_manager, processor):
    """Filter, standardize and featurize one dataset, then save it to disk.

    Steps:
      1. Cast ``audio`` to ``Audio(decode=False)``.
      2. Keep rows whose transcript passes ``is_valid_english_script`` and
         ``lexical_filter``.
      3. Reduce every row to ``audio`` / ``text`` / ``source_dataset``.
      4. Run ``preprocess_batch`` in a batched, multi-process ``map``.
      5. Save the result to ``save_path``.

    Args:
        ds: Dataset with an ``audio`` column.
        text_keys: Column names to try, in order, for the transcript. If
            none yields a value, ``sentence`` / ``text`` / ``transcription``
            / ``normalized_text`` are tried.
        source_label: Value written to ``source_dataset``; also used in
            progress messages.
        save_path: Output directory for ``save_to_disk``.
        processor_dir: Forwarded to ``preprocess_batch``.
        dict_path: Forwarded to ``preprocess_batch``.
        num_proc: Number of processes for the feature-extraction map.
        batch_size: Batch size for the feature-extraction map.
        g2p_manager: Used by the lexical filter.
        processor: Its ``tokenizer`` is used by the lexical filter.

    Returns:
        Number of rows in the saved dataset.
    """
    print(f"\n--- Processing {source_label} ---")

    def extract_text(example):
        """Return the stripped transcript of ``example`` ("" if none found)."""
        text = ""
        if text_keys:
            for k in text_keys:
                if example.get(k):
                    text = example[k]
                    break
        if not text:
            text = example.get("sentence") or example.get("text") or example.get("transcription") or example.get("normalized_text") or ""
        return str(text).strip()

    # 1. Keep audio undecoded while filtering and standardizing.
    ds = ds.cast_column("audio", Audio(decode=False))

    # 2. Filter
    def filter_fn(example):
        text = extract_text(example)
        return is_valid_english_script(text) and lexical_filter(text, g2p_manager, processor.tokenizer)

    ds_filtered = ds.filter(filter_fn, desc=f"Filtering {source_label}")

    # 3. Standardize structure
    def map_fn(example):
        return {
            "audio": example["audio"],
            "text": extract_text(example),
            "source_dataset": source_label,
        }

    columns_to_remove = [col for col in ds_filtered.column_names if col not in ["audio", "text", "source_dataset"]]
    ds_standardized = ds_filtered.map(map_fn, remove_columns=columns_to_remove, desc=f"Standardizing {source_label}")

    # 4. Map to features (audio features + phoneme labels)
    print(f"Running preprocessing map for {source_label} with {num_proc} processes...")
    ds_preprocessed = ds_standardized.map(
        preprocess_batch,
        fn_kwargs={"processor_dir": processor_dir, "dict_path": dict_path},
        batched=True,
        batch_size=batch_size,
        num_proc=num_proc,
        # preprocess_batch may return fewer rows than it received, so every
        # input column has to be dropped.
        remove_columns=ds_standardized.column_names,
        desc=f"Extracting features for {source_label}",
    )

    # 5. Save directly to disk. dirname is "" for a bare directory name, and
    # os.makedirs("") raises.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    ds_preprocessed.save_to_disk(save_path)

    n_saved = len(ds_preprocessed)
    print(f"✓ Successfully preprocessed and saved {n_saved} samples to {save_path}")
    return n_saved
def _preprocess_and_save_nptel_chunk(samples, chunk_idx, nptel_parts_dir, args, desc):
    """Featurize one buffered NPTEL chunk, save it, and return its path."""
    chunk_ds = Dataset.from_list(samples)
    chunk_save_path = os.path.join(nptel_parts_dir, f"chunk_{chunk_idx}")
    chunk_ds_preprocessed = chunk_ds.map(
        preprocess_batch,
        fn_kwargs={"processor_dir": args.processor_dir, "dict_path": args.dict_path},
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        remove_columns=chunk_ds.column_names,
        desc=desc,
    )
    chunk_ds_preprocessed.save_to_disk(chunk_save_path)
    return chunk_save_path


def main():
    """Command-line entry point for the preprocessing pipeline.

    Each source dataset is preprocessed into its own directory under
    ``--parts_dir`` (directories that already contain ``dataset_info.json``
    are reused), NPTEL is streamed in chunks until it matches the combined
    size of the other sources, and all parts are then concatenated,
    shuffled, split 90/10 and saved to ``--save_dir``. The parts directory
    is removed after a successful save.
    """
    parser = argparse.ArgumentParser(description="OOM-proof Preprocessing Pipeline for CDAC ASR")
    parser.add_argument("--processor_dir", default="models/processor_dir")
    parser.add_argument("--dict_path", default="src/g2p/output_v2_detailed.dict")
    parser.add_argument("--save_dir", default="/data/local_nptel_processed")
    parser.add_argument("--local_openslr_dir", default="/data/local_openslr_104")
    parser.add_argument("--parts_dir", default="/data/preprocessed_parts")
    parser.add_argument("--num_proc", type=int, default=40)
    parser.add_argument("--batch_size", type=int, default=250)
    parser.add_argument("--hf_token", default=None)
    args = parser.parse_args()

    # Treat an empty or literal "none" token as no token at all.
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    if isinstance(hf_token, str) and hf_token.strip().lower() in ["none", ""]:
        hf_token = None

    print("Checking NLTK resources...")
    for res in ['averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng', 'cmudict']:
        nltk.download(res, quiet=True)

    # Instantiate once up front; the instance itself is not used.
    print("Warming up Silero VAD cache...")
    _ = AudioPreprocessor(sr=16000)

    processor = Wav2Vec2Processor.from_pretrained(args.processor_dir)
    g2p_manager = G2PManager(dict_path=args.dict_path)

    # (hub path, split, transcript columns, part label, config name)
    configs = [
        ("WillHeld/india_accent_cv", "train", ["sentence"], "common_voice", None),
        ("theothertom/indian_english_extended", "train", ["transcription", "sentence"], "theothertom_extended", None),
        ("theothertom/indian_english_bigger", "train", ["transcription", "sentence"], "theothertom_bigger", None),
        ("theothertom/indian_english_audio_2", "train", ["transcription", "sentence"], "theothertom_audio_2", None),
        ("ai4bharat/Svarah", "test", ["transcription"], "svarah", None),
        ("eka-care/medical-asr", "train", ["transcription", "text"], "eka_care", None)
    ]

    parts_counts = {}
    preprocessed_datasets = []  # paths of saved parts, in assembly order

    # 1. Process standard datasets one-by-one
    for path, split, text_keys, label, conf in configs:
        part_save_path = os.path.join(args.parts_dir, label)

        # Resume support: reuse a part that was already saved.
        if os.path.exists(os.path.join(part_save_path, "dataset_info.json")):
            print(f"✓ Part {label} already preprocessed on disk. Loading...")
            parts_counts[label] = len(load_from_disk(part_save_path))
            preprocessed_datasets.append(part_save_path)
            continue

        try:
            print(f"\nLoading {path}...")
            if conf:
                ds = load_dataset(path, conf, split=split, token=hf_token)
            else:
                ds = load_dataset(path, split=split, token=hf_token)

            if label == "eka_care":
                ds = ds.filter(lambda x: not x.get("is_synthetic", False))

            count = preprocess_and_save_dataset(
                ds, text_keys, label, part_save_path,
                args.processor_dir, args.dict_path, args.num_proc,
                args.batch_size, g2p_manager, processor
            )
            parts_counts[label] = count
            preprocessed_datasets.append(part_save_path)
        except Exception as e:
            # Best-effort: one failing source must not stop the others.
            print(f"⚠️ Error processing {label}: {e}")

    # 2. Process OpenSLR 104
    openslr_part_path = os.path.join(args.parts_dir, "openslr_104")
    if os.path.exists(os.path.join(openslr_part_path, "dataset_info.json")):
        print("✓ OpenSLR 104 already preprocessed. Loading...")
        parts_counts["openslr_104"] = len(load_from_disk(openslr_part_path))
        preprocessed_datasets.append(openslr_part_path)
    elif os.path.exists(args.local_openslr_dir):
        try:
            print(f"Loading local OpenSLR 104 from {args.local_openslr_dir}...")
            local_ds = load_from_disk(args.local_openslr_dir)
            count = preprocess_and_save_dataset(
                local_ds, ["transcription", "sentence", "text"], "openslr_104", openslr_part_path,
                args.processor_dir, args.dict_path, args.num_proc,
                args.batch_size, g2p_manager, processor
            )
            parts_counts["openslr_104"] = count
            preprocessed_datasets.append(openslr_part_path)
        except Exception as e:
            print(f"⚠️ Error loading OpenSLR 104: {e}")
    else:
        print("⚠️ OpenSLR 104 local directory not found! Skipping OpenSLR.")

    # 3. Sum other datasets to determine NPTEL balance count
    n_others = sum(parts_counts.values())
    print(f"\nTotal non-NPTEL samples processed: {n_others}")

    # 4. Stream and process NPTEL in 5000-sample chunks
    nptel_parts_dir = os.path.join(args.parts_dir, "nptel_chunks")
    os.makedirs(nptel_parts_dir, exist_ok=True)

    # Chunks left over from a previous run; sorted so assembly order does
    # not depend on os.listdir ordering.
    existing_nptel_parts = sorted(
        os.path.join(nptel_parts_dir, d)
        for d in os.listdir(nptel_parts_dir)
        if os.path.exists(os.path.join(nptel_parts_dir, d, "dataset_info.json"))
    )
    n_nptel_loaded = sum(len(load_from_disk(p)) for p in existing_nptel_parts)
    print(f"Already preprocessed NPTEL samples found on disk: {n_nptel_loaded}/{n_others}")

    # Existing chunks belong in the final dataset whether or not more NPTEL
    # data still has to be streamed; their samples are skipped in the stream
    # below, so leaving them out here would lose them.
    preprocessed_datasets.extend(existing_nptel_parts)

    if n_nptel_loaded >= n_others:
        print("✓ NPTEL balancing dataset already fully preprocessed on disk.")
    else:
        print(f"Streaming remaining NPTEL data from HuggingFace to match {n_others} target...")
        try:
            nptel_ds = load_dataset("skbose/indian-english-nptel-v0", split="train", streaming=True, token=hf_token)
            nptel_ds = nptel_ds.cast_column("audio", Audio(decode=False))

            chunk_size = 5000
            current_chunk = []
            chunk_idx = len(existing_nptel_parts)
            loaded = n_nptel_loaded
            checked = 0

            # NOTE(review): the resume skip count is the number of rows in
            # the saved chunks, but preprocess_batch can drop rows, so this
            # may be smaller than the number of stream records consumed by
            # the previous run -- confirm whether a few duplicates matter.
            skipped = 0

            for sample in nptel_ds:
                checked += 1
                if checked % 1000 == 0:
                    print(f"  [NPTEL Stream] Checked {checked} stream records, matched {loaded + len(current_chunk)}/{n_others}...", flush=True)

                text = sample.get("text") or sample.get("transcription") or ""
                text = str(text).strip()
                if not (is_valid_english_script(text) and lexical_filter(text, g2p_manager, processor.tokenizer)):
                    continue

                if skipped < n_nptel_loaded:
                    skipped += 1
                    continue

                current_chunk.append({
                    "audio": sample["audio"],
                    "text": text,
                    "source_dataset": "nptel"
                })

                if len(current_chunk) >= chunk_size or (loaded + len(current_chunk)) >= n_others:
                    # Flush the buffer: featurize and save this chunk.
                    print(f"\nProcessing NPTEL shard chunk {chunk_idx} ({len(current_chunk)} samples)...")
                    chunk_save_path = _preprocess_and_save_nptel_chunk(
                        current_chunk, chunk_idx, nptel_parts_dir, args,
                        desc=f"Preprocessing NPTEL chunk {chunk_idx}",
                    )
                    print(f"✓ Saved NPTEL chunk {chunk_idx} to {chunk_save_path}")
                    preprocessed_datasets.append(chunk_save_path)

                    loaded += len(current_chunk)
                    current_chunk = []
                    chunk_idx += 1

                    if loaded >= n_others:
                        break

            # The stream ended before the target was reached: save what is
            # still buffered.
            if current_chunk and loaded < n_others:
                chunk_save_path = _preprocess_and_save_nptel_chunk(
                    current_chunk, chunk_idx, nptel_parts_dir, args,
                    desc=f"Preprocessing final NPTEL chunk",
                )
                preprocessed_datasets.append(chunk_save_path)
                loaded += len(current_chunk)
                print(f"✓ Saved final NPTEL chunk to {chunk_save_path}")

            print(f"✓ NPTEL preprocessing complete. Balanced with {loaded} NPTEL samples.")
        except Exception as e:
            print(f"⚠️ Error during NPTEL processing: {e}")

    # 5. Concatenate all saved parts
    print("\n--- Final Dataset Assembly ---")
    if not preprocessed_datasets:
        # Nothing was produced; keep the parts directory for inspection.
        print("⚠️ No preprocessed partitions are available. Nothing to assemble.")
        return

    print(f"Loading all {len(preprocessed_datasets)} preprocessed partitions from disk...")
    loaded_parts = [load_from_disk(p) for p in preprocessed_datasets]

    print("Concatenating all parts...")
    final_dataset = concatenate_datasets(loaded_parts)
    print(f"✓ Concatenated. Total samples: {len(final_dataset)}")

    print("Shuffling combined dataset out-of-core...")
    final_dataset = final_dataset.shuffle(seed=42)

    print("Splitting dataset into train and test splits (10% test)...")
    dataset_dict = final_dataset.train_test_split(test_size=0.1, seed=42)

    print(f"Saving final DatasetDict to disk at '{args.save_dir}'...")
    dataset_dict.save_to_disk(args.save_dir)
    print("✅ Preprocessing, train-test split, and save completed successfully!")

    # 6. Clean up temporary parts
    print(f"Cleaning up temporary part files in {args.parts_dir}...")
    try:
        shutil.rmtree(args.parts_dir)
        print("✓ Cleaned up temporary parts directory.")
    except Exception as e:
        print(f"Warning: Cleanup failed: {e}")


if __name__ == "__main__":
    main()