Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import io | |
| import re | |
| import shutil | |
| import argparse | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| import torch | |
| import nltk | |
| from datasets import load_dataset, Audio, load_from_disk, concatenate_datasets, Dataset, DatasetDict | |
# Put the directory two levels above this file on sys.path so that the
# project-local `src.*` imports resolve when this file is run as a script.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
| from transformers import Wav2Vec2Processor | |
| from src.g2p.g2p_utils import G2PManager | |
| from src.utils.audio_utils import AudioPreprocessor | |
# Per-process singletons. Each starts as None and is populated by
# init_worker() the first time it runs inside a given process.
PREPROCESSOR = None
PROCESSOR = None
G2P_MANAGER = None


def init_worker(processor_dir, dict_path):
    """Lazily create the module-level preprocessing helpers.

    Each global is built only while it is still ``None``, so calling this
    function repeatedly in the same process is cheap.

    Args:
        processor_dir: Directory passed to ``Wav2Vec2Processor.from_pretrained``.
        dict_path: Pronunciation dictionary path passed to ``G2PManager``.
    """
    global PREPROCESSOR, PROCESSOR, G2P_MANAGER

    need_preprocessor = PREPROCESSOR is None
    need_processor = PROCESSOR is None
    need_g2p = G2P_MANAGER is None

    if need_preprocessor:
        # Torch is limited to one thread before the audio preprocessor is
        # built (presumably to avoid oversubscription across map workers).
        torch.set_num_threads(1)
        PREPROCESSOR = AudioPreprocessor(sr=16000)

    if need_processor:
        PROCESSOR = Wav2Vec2Processor.from_pretrained(processor_dir)

    if need_g2p:
        G2P_MANAGER = G2PManager(dict_path=dict_path)
def preprocess_batch(batch, processor_dir, dict_path):
    """Turn a batch of raw audio/text rows into model inputs and label ids.

    For every clip the audio is loaded (from in-memory bytes, a decoded
    array, or a file path), resampled to 16 kHz when needed, cleaned by the
    module-level ``PREPROCESSOR``, and fed to ``PROCESSOR``. The transcript
    is converted to phonemes by ``G2P_MANAGER`` and then to token ids.

    Clips that fail at any stage are dropped, so the returned lists may be
    shorter than the incoming batch. A single summary line per batch is
    written to stderr when that happens so the loss is visible.

    Args:
        batch: Column-oriented batch; must contain ``"audio"`` and may
            contain one of ``"text"``, ``"transcription"``, ``"sentence"``
            or ``"normalized_text"``.
        processor_dir: Forwarded to ``init_worker``.
        dict_path: Forwarded to ``init_worker``.

    Returns:
        ``{"input_values": [...], "labels": [...]}`` with one entry per
        successfully processed clip.
    """
    # Resolved once per call instead of once per clip.
    import math

    try:
        from scipy.signal import resample_poly
    except ImportError:
        resample_poly = None  # librosa is used as the fallback resampler

    init_worker(processor_dir, dict_path)

    audios = batch["audio"]
    text_key = next(
        (
            key
            for key in ("text", "transcription", "sentence", "normalized_text")
            if key in batch
        ),
        None,
    )
    texts = batch[text_key] if text_key is not None else [""] * len(audios)

    input_values_list = []
    labels_list = []
    skipped = 0
    last_error = None

    for i, audio_data in enumerate(audios):
        try:
            text = texts[i] if i < len(texts) else ""

            is_mapping = isinstance(audio_data, dict)
            if is_mapping and audio_data.get("bytes") is not None:
                audio_array, sr = sf.read(io.BytesIO(audio_data["bytes"]))
                read_by_soundfile = True
            elif is_mapping and audio_data.get("array") is not None:
                audio_array = np.array(audio_data["array"])
                # A present-but-None sampling rate falls back to 16 kHz, the
                # same default already used when the key is missing.
                sr = audio_data.get("sampling_rate") or 16000
                read_by_soundfile = False
            elif is_mapping and audio_data.get("path") is not None:
                audio_array, sr = sf.read(audio_data["path"])
                read_by_soundfile = True
            else:
                raise ValueError("Invalid audio format or missing audio content.")

            # soundfile.read yields a (frames, channels) array for
            # multi-channel files; average to mono so the rest of the
            # pipeline receives a 1-D clip.
            if read_by_soundfile and audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)

            if sr != 16000:
                try:
                    if resample_poly is None:
                        raise ImportError("scipy is not available")
                    divisor = math.gcd(sr, 16000)
                    audio_array = resample_poly(
                        audio_array, 16000 // divisor, sr // divisor
                    )
                except Exception:
                    # Any failure of the polyphase path falls back to librosa.
                    audio_array = librosa.resample(
                        audio_array,
                        orig_sr=sr,
                        target_sr=16000,
                        res_type="kaiser_fast",
                    )

            clean_audio = PREPROCESSOR.preprocess(audio_array)
            if len(clean_audio) == 0:
                raise ValueError("Audio clip is empty after FFT filtering and VAD silence trimming.")

            input_values = PROCESSOR(clean_audio, sampling_rate=16000).input_values[0]

            phonemes = G2P_MANAGER.convert_sentence(text)
            if len(phonemes) == 0:
                raise ValueError("Phoneme sequence is empty after G2P conversion.")
            labels = PROCESSOR.tokenizer.convert_tokens_to_ids(phonemes)

            # Append only after both halves succeeded so the lists stay aligned.
            input_values_list.append(input_values)
            labels_list.append(labels)
        except Exception as exc:
            # Best effort: one bad clip must not abort the whole map, but the
            # drop is recorded instead of being silently discarded.
            skipped += 1
            last_error = exc

    if skipped:
        print(
            f"[preprocess_batch] skipped {skipped}/{len(audios)} clips; "
            f"last error: {last_error!r}",
            file=sys.stderr,
        )

    return {"input_values": input_values_list, "labels": labels_list}
def is_valid_english_script(text):
    """Return True when *text* is non-empty, pure ASCII, and has a Latin letter.

    Empty or ``None`` input, any non-ASCII character, or text made only of
    digits/punctuation/whitespace yields False.
    """
    if text and text.isascii():
        return re.search(r"[A-Za-z]", text) is not None
    return False
# Vocabulary of the tokenizer most recently passed to lexical_filter(), and
# that tokenizer itself, so the cache is never reused for a different one.
_VOCAB_CACHE = None
_VOCAB_CACHE_TOKENIZER = None


def lexical_filter(text, g2p_manager, tokenizer):
    """Return True if at least one word of *text* maps to a known phoneme.

    A word counts as valid when ``g2p_manager.convert_word`` yields at least
    one phoneme that is present in the tokenizer vocabulary.

    Args:
        text: Transcript to check.
        g2p_manager: Object providing ``tokenize(text)`` and
            ``convert_word(word)``.
        tokenizer: Object providing ``get_vocab()``.

    Returns:
        False when the text tokenizes to no words or no word is valid,
        True otherwise.
    """
    global _VOCAB_CACHE, _VOCAB_CACHE_TOKENIZER

    words = g2p_manager.tokenize(text)
    if not words:
        return False

    # get_vocab() is fetched once per tokenizer; the identity check keeps a
    # second tokenizer from silently receiving the first one's vocabulary.
    if _VOCAB_CACHE is None or _VOCAB_CACHE_TOKENIZER is not tokenizer:
        _VOCAB_CACHE = tokenizer.get_vocab()
        _VOCAB_CACHE_TOKENIZER = tokenizer
    vocab = _VOCAB_CACHE

    # Stop at the first valid word: only "any valid word?" matters here.
    for word in words:
        phonemes = g2p_manager.convert_word(word)
        if any(p in vocab for p in phonemes):
            return True
    return False
def build_and_apply_vocab_patch(dataset, processor, g2p_manager, patch_path):
    """Find words whose phonemes hit the unknown token and patch them.

    Words are collected from every sample whose ``source_dataset`` is not
    ``"nptel"`` (samples without that key are treated as NPTEL and skipped).
    For each word whose phoneme ids contain the tokenizer's unknown id, a
    cleaned phoneme list is built: in-vocabulary phonemes are kept as-is, and
    out-of-vocabulary ones are reduced to the characters that are themselves
    vocabulary entries. The resulting patches are merged into the
    tab-separated file at *patch_path* and into ``g2p_manager.phoneme_dict``.

    Args:
        dataset: Iterable of mapping-like samples.
        processor: Object exposing ``tokenizer`` with ``unk_token_id``,
            ``get_vocab()`` and ``convert_tokens_to_ids()``.
        g2p_manager: Object exposing ``tokenize``, ``convert_word`` and a
            mutable ``phoneme_dict``.
        patch_path: Destination patch file (``word<TAB>ph1 ph2 ...`` lines).
    """
    print("Running G2P vocabulary verification check...")
    tokenizer = processor.tokenizer

    # Only fall back to 1 when the id is genuinely missing; `or 1` would
    # replace a legitimate unknown-token id of 0.
    unk_id = tokenizer.unk_token_id
    if unk_id is None:
        unk_id = 1

    words_to_check = set()
    for sample in dataset:
        if sample.get("source_dataset", "nptel") == "nptel":
            continue
        text = sample.get("text") or sample.get("transcription") or sample.get("sentence") or ""
        words_to_check.update(g2p_manager.tokenize(text))
    print(f"Analyzing {len(words_to_check)} unique words from non-NPTEL datasets...")

    vocab = tokenizer.get_vocab()
    new_patches = {}
    for word in words_to_check:
        phonemes = g2p_manager.convert_word(word)
        if len(phonemes) == 0:
            continue
        ids = tokenizer.convert_tokens_to_ids(phonemes)
        if not any(i == unk_id for i in ids):
            continue

        cleaned_phonemes = []
        for phoneme in phonemes:
            if phoneme in vocab:
                cleaned_phonemes.append(phoneme)
                continue
            # NOTE(review): the joined string is not re-checked against the
            # vocabulary, so it may itself still be an unknown token.
            closest = "".join(char for char in phoneme if char in vocab)
            if closest:
                cleaned_phonemes.append(closest)
        if cleaned_phonemes:
            new_patches[word] = cleaned_phonemes

    if not new_patches:
        print("✓ No vocabulary patches needed. All words mapped successfully.")
        return

    print(f"Writing {len(new_patches)} new vocabulary patches to {patch_path}...")
    existing_patches = {}
    if os.path.exists(patch_path):
        with open(patch_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) >= 2:
                    existing_patches[parts[0]] = parts[1].split()
    # New patches win over entries already present in the file.
    existing_patches.update(new_patches)

    # A bare filename has an empty dirname, and os.makedirs("") raises.
    patch_dir = os.path.dirname(patch_path)
    if patch_dir:
        os.makedirs(patch_dir, exist_ok=True)
    with open(patch_path, "w", encoding="utf-8") as f:
        for w, phs in sorted(existing_patches.items()):
            f.write(f"{w}\t{' '.join(phs)}\n")

    g2p_manager.phoneme_dict.update(new_patches)
    print("✅ Vocabulary patch successfully updated and merged!")
# Keys tried, in this order, when none of the caller-supplied keys has a value.
_FALLBACK_TEXT_KEYS = ("sentence", "text", "transcription", "normalized_text")


def _extract_text(example, text_keys):
    """Return the first non-empty transcript found in *example*, stripped."""
    for key in (*(text_keys or ()), *_FALLBACK_TEXT_KEYS):
        value = example.get(key)
        if value:
            return str(value).strip()
    return ""


def preprocess_and_save_dataset(ds, text_keys, source_label, save_path, processor_dir, dict_path, num_proc, batch_size, g2p_manager, processor):
    """Filter, standardize, featurize and save one dataset.

    Steps:
        1. Cast ``audio`` to ``Audio(decode=False)`` so the column is not
           decoded by the datasets library.
        2. Keep only rows whose transcript passes
           ``is_valid_english_script`` and ``lexical_filter``.
        3. Reduce every row to ``audio``, ``text`` and ``source_dataset``.
        4. Run ``preprocess_batch`` in ``num_proc`` processes.
        5. Save the result at *save_path*.

    Args:
        ds: Dataset with an ``audio`` column and a transcript column.
        text_keys: Column names to try first for the transcript.
        source_label: Value written to the ``source_dataset`` column.
        save_path: Output directory for ``save_to_disk``.
        processor_dir: Forwarded to ``preprocess_batch``.
        dict_path: Forwarded to ``preprocess_batch``.
        num_proc: Number of processes for the feature-extraction map.
        batch_size: Batch size for the feature-extraction map.
        g2p_manager: G2P manager used by the lexical filter.
        processor: Processor whose tokenizer is used by the lexical filter.

    Returns:
        Number of rows in the saved dataset.
    """
    print(f"\n--- Processing {source_label} ---")

    # 1. Keep audio undecoded.
    ds = ds.cast_column("audio", Audio(decode=False))

    # 2. Filter on the transcript.
    def filter_fn(example):
        text = _extract_text(example, text_keys)
        return is_valid_english_script(text) and lexical_filter(text, g2p_manager, processor.tokenizer)

    ds_filtered = ds.filter(filter_fn, desc=f"Filtering {source_label}")

    # 3. Standardize structure.
    def map_fn(example):
        return {
            "audio": example["audio"],
            "text": _extract_text(example, text_keys),
            "source_dataset": source_label,
        }

    kept_columns = ("audio", "text", "source_dataset")
    columns_to_remove = [col for col in ds_filtered.column_names if col not in kept_columns]
    ds_standardized = ds_filtered.map(
        map_fn,
        remove_columns=columns_to_remove,
        desc=f"Standardizing {source_label}",
    )

    # 4. Map to features (audio features + phoneme labels). All input columns
    # are removed because preprocess_batch may return fewer rows than it got.
    print(f"Running preprocessing map for {source_label} with {num_proc} processes...")
    ds_preprocessed = ds_standardized.map(
        preprocess_batch,
        fn_kwargs={"processor_dir": processor_dir, "dict_path": dict_path},
        batched=True,
        batch_size=batch_size,
        num_proc=num_proc,
        remove_columns=ds_standardized.column_names,
        desc=f"Extracting features for {source_label}",
    )

    # 5. Save. A bare directory name has an empty dirname, and
    # os.makedirs("") raises, so only create the parent when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    ds_preprocessed.save_to_disk(save_path)

    n_saved = len(ds_preprocessed)
    print(f"✓ Successfully preprocessed and saved {n_saved} samples to {save_path}")
    return n_saved
def _preprocess_and_save_nptel_chunk(records, save_path, args, desc):
    """Run preprocess_batch over buffered NPTEL records and save the result."""
    chunk_ds = Dataset.from_list(records)
    chunk_ds.map(
        preprocess_batch,
        fn_kwargs={"processor_dir": args.processor_dir, "dict_path": args.dict_path},
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        remove_columns=chunk_ds.column_names,
        desc=desc,
    ).save_to_disk(save_path)


def main():
    """Command-line entry point for the preprocessing pipeline.

    Steps:
        1. Preprocess each configured Hub dataset into its own directory
           under ``--parts_dir`` (already-saved parts are reused).
        2. Do the same for a local OpenSLR 104 copy, if present.
        3. Count the non-NPTEL rows to get the NPTEL target size.
        4. Stream NPTEL, filter it, and preprocess it in 5000-row chunks
           until the target is met.
        5. Concatenate all parts, shuffle, split 90/10 and save to
           ``--save_dir``.
        6. Delete ``--parts_dir``.

    Failures in steps 1, 2 and 4 are reported and skipped; the script exits
    with an error message if no part at all is available for step 5.
    """
    parser = argparse.ArgumentParser(description="OOM-proof Preprocessing Pipeline for CDAC ASR")
    parser.add_argument("--processor_dir", default="models/processor_dir")
    parser.add_argument("--dict_path", default="src/g2p/output_v2_detailed.dict")
    parser.add_argument("--save_dir", default="/data/local_nptel_processed")
    parser.add_argument("--local_openslr_dir", default="/data/local_openslr_104")
    parser.add_argument("--parts_dir", default="/data/preprocessed_parts")
    parser.add_argument("--num_proc", type=int, default=40)
    parser.add_argument("--batch_size", type=int, default=250)
    parser.add_argument("--hf_token", default=None)
    args = parser.parse_args()

    # Treat an empty or literal "none" token as "no token".
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    if isinstance(hf_token, str) and hf_token.strip().lower() in ["none", ""]:
        hf_token = None

    print("Checking NLTK resources...")
    for res in ['averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng', 'cmudict']:
        nltk.download(res, quiet=True)

    print("Warming up Silero VAD cache...")
    _ = AudioPreprocessor(sr=16000)

    processor = Wav2Vec2Processor.from_pretrained(args.processor_dir)
    g2p_manager = G2PManager(dict_path=args.dict_path)

    # (hub path, split, transcript keys, source label, optional config name)
    configs = [
        ("WillHeld/india_accent_cv", "train", ["sentence"], "common_voice", None),
        ("theothertom/indian_english_extended", "train", ["transcription", "sentence"], "theothertom_extended", None),
        ("theothertom/indian_english_bigger", "train", ["transcription", "sentence"], "theothertom_bigger", None),
        ("theothertom/indian_english_audio_2", "train", ["transcription", "sentence"], "theothertom_audio_2", None),
        ("ai4bharat/Svarah", "test", ["transcription"], "svarah", None),
        ("eka-care/medical-asr", "train", ["transcription", "text"], "eka_care", None),
    ]

    parts_counts = {}
    preprocessed_datasets = []  # on-disk paths of every finished part

    # 1. Process standard datasets one by one.
    for path, split, text_keys, label, conf in configs:
        part_save_path = os.path.join(args.parts_dir, label)

        # Resume support: reuse a part that is already saved.
        if os.path.exists(os.path.join(part_save_path, "dataset_info.json")):
            print(f"✓ Part {label} already preprocessed on disk. Loading...")
            parts_counts[label] = len(load_from_disk(part_save_path))
            preprocessed_datasets.append(part_save_path)
            continue

        try:
            print(f"\nLoading {path}...")
            load_args = (path, conf) if conf else (path,)
            ds = load_dataset(*load_args, split=split, token=hf_token)

            if label == "eka_care":
                ds = ds.filter(lambda x: not x.get("is_synthetic", False))

            count = preprocess_and_save_dataset(
                ds, text_keys, label, part_save_path,
                args.processor_dir, args.dict_path, args.num_proc,
                args.batch_size, g2p_manager, processor
            )
            parts_counts[label] = count
            preprocessed_datasets.append(part_save_path)
        except Exception as e:
            print(f"⚠️ Error processing {label}: {e}")

    # 2. Process OpenSLR 104.
    openslr_part_path = os.path.join(args.parts_dir, "openslr_104")
    if os.path.exists(os.path.join(openslr_part_path, "dataset_info.json")):
        print("✓ OpenSLR 104 already preprocessed. Loading...")
        parts_counts["openslr_104"] = len(load_from_disk(openslr_part_path))
        preprocessed_datasets.append(openslr_part_path)
    elif os.path.exists(args.local_openslr_dir):
        try:
            print(f"Loading local OpenSLR 104 from {args.local_openslr_dir}...")
            local_ds = load_from_disk(args.local_openslr_dir)
            count = preprocess_and_save_dataset(
                local_ds, ["transcription", "sentence", "text"], "openslr_104", openslr_part_path,
                args.processor_dir, args.dict_path, args.num_proc,
                args.batch_size, g2p_manager, processor
            )
            parts_counts["openslr_104"] = count
            preprocessed_datasets.append(openslr_part_path)
        except Exception as e:
            print(f"⚠️ Error loading OpenSLR 104: {e}")
    else:
        print("⚠️ OpenSLR 104 local directory not found! Skipping OpenSLR.")

    # 3. Sum other datasets to determine the NPTEL balance count.
    n_others = sum(parts_counts.values())
    print(f"\nTotal non-NPTEL samples processed: {n_others}")

    # 4. Stream and process NPTEL in 5000-sample chunks.
    nptel_parts_dir = os.path.join(args.parts_dir, "nptel_chunks")
    os.makedirs(nptel_parts_dir, exist_ok=True)

    # Chunks saved by an earlier run; sorted so assembly order is deterministic.
    existing_nptel_parts = sorted(
        os.path.join(nptel_parts_dir, d)
        for d in os.listdir(nptel_parts_dir)
        if os.path.exists(os.path.join(nptel_parts_dir, d, "dataset_info.json"))
    )
    n_nptel_loaded = sum(len(load_from_disk(p)) for p in existing_nptel_parts)
    print(f"Already preprocessed NPTEL samples found on disk: {n_nptel_loaded}/{n_others}")

    # Existing chunks always belong in the final dataset, including when more
    # NPTEL data still has to be streamed. Leaving them out on a partial
    # resume would drop them from the output and then delete them in step 6.
    preprocessed_datasets.extend(existing_nptel_parts)

    if n_nptel_loaded >= n_others:
        print("✓ NPTEL balancing dataset already fully preprocessed on disk.")
    else:
        print(f"Streaming remaining NPTEL data from HuggingFace to match {n_others} target...")
        try:
            nptel_ds = load_dataset("skbose/indian-english-nptel-v0", split="train", streaming=True, token=hf_token)
            nptel_ds = nptel_ds.cast_column("audio", Audio(decode=False))

            chunk_size = 5000
            current_chunk = []
            chunk_idx = len(existing_nptel_parts)
            loaded = n_nptel_loaded
            checked = 0
            skipped = 0  # filter-passing records skipped because of a resume

            for sample in nptel_ds:
                checked += 1
                if checked % 1000 == 0:
                    print(f" [NPTEL Stream] Checked {checked} stream records, matched {loaded + len(current_chunk)}/{n_others}...", flush=True)

                text = sample.get("text") or sample.get("transcription") or ""
                text = str(text).strip()
                if not (is_valid_english_script(text) and lexical_filter(text, g2p_manager, processor.tokenizer)):
                    continue

                # NOTE(review): n_nptel_loaded counts rows that survived
                # preprocess_batch, while this skip counts rows that passed
                # the text filter. If the earlier run dropped clips during
                # preprocessing, the resume point lands slightly early.
                if skipped < n_nptel_loaded:
                    skipped += 1
                    continue

                current_chunk.append({
                    "audio": sample["audio"],
                    "text": text,
                    "source_dataset": "nptel"
                })

                chunk_full = len(current_chunk) >= chunk_size
                target_reached = (loaded + len(current_chunk)) >= n_others
                if not (chunk_full or target_reached):
                    continue

                # Process and save this chunk to disk.
                chunk_save_path = os.path.join(nptel_parts_dir, f"chunk_{chunk_idx}")
                print(f"\nProcessing NPTEL shard chunk {chunk_idx} ({len(current_chunk)} samples)...")
                _preprocess_and_save_nptel_chunk(
                    current_chunk, chunk_save_path, args,
                    f"Preprocessing NPTEL chunk {chunk_idx}",
                )
                print(f"✓ Saved NPTEL chunk {chunk_idx} to {chunk_save_path}")
                preprocessed_datasets.append(chunk_save_path)

                loaded += len(current_chunk)
                current_chunk = []
                chunk_idx += 1
                if loaded >= n_others:
                    break

            # The stream ran out before the target was met: flush the rest.
            if current_chunk and loaded < n_others:
                chunk_save_path = os.path.join(nptel_parts_dir, f"chunk_{chunk_idx}")
                _preprocess_and_save_nptel_chunk(
                    current_chunk, chunk_save_path, args,
                    "Preprocessing final NPTEL chunk",
                )
                preprocessed_datasets.append(chunk_save_path)
                loaded += len(current_chunk)
                print(f"✓ Saved final NPTEL chunk to {chunk_save_path}")

            print(f"✓ NPTEL preprocessing complete. Balanced with {loaded} NPTEL samples.")
        except Exception as e:
            print(f"⚠️ Error during NPTEL processing: {e}")

    # 5. Concatenate all parts.
    print("\n--- Final Dataset Assembly ---")
    if not preprocessed_datasets:
        # Every source failed or was skipped; stop with a clear message
        # instead of failing inside concatenate_datasets.
        sys.exit("No preprocessed partitions are available; nothing to assemble.")

    print(f"Loading all {len(preprocessed_datasets)} preprocessed partitions from disk...")
    loaded_parts = [load_from_disk(p) for p in preprocessed_datasets]

    print("Concatenating all parts...")
    final_dataset = concatenate_datasets(loaded_parts)
    print(f"✓ Concatenated. Total samples: {len(final_dataset)}")

    print("Shuffling combined dataset out-of-core...")
    final_dataset = final_dataset.shuffle(seed=42)

    print("Splitting dataset into train and test splits (10% test)...")
    dataset_dict = final_dataset.train_test_split(test_size=0.1, seed=42)

    print(f"Saving final DatasetDict to disk at '{args.save_dir}'...")
    dataset_dict.save_to_disk(args.save_dir)
    print("✅ Preprocessing, train-test split, and save completed successfully!")

    # 6. Clean up temporary parts (only reached after a successful save).
    print(f"Cleaning up temporary part files in {args.parts_dir}...")
    try:
        shutil.rmtree(args.parts_dir)
        print("✓ Cleaned up temporary parts directory.")
    except Exception as e:
        print(f"Warning: Cleanup failed: {e}")
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()