Spaces: Running on Zero
| import json | |
| import argparse | |
| from pathlib import Path | |
| from typing import List, Dict, Set | |
| from tqdm import tqdm | |
| import soundfile as sf | |
| from datasets import load_dataset | |
| import logging | |
| import os | |
def parse_args() -> argparse.Namespace:
    """Parse and return command-line options for CapSpeech data preparation.

    Returns:
        argparse.Namespace with the hub repo, output/cache directories,
        per-corpus WAV directories, duration bounds, splits, and debug flag.
    """
    cli = argparse.ArgumentParser(description="Prepare the CapSpeech dataset")
    # Dataset location and output layout.
    cli.add_argument('--hub', type=str, required=True, help='Huggingface repo')
    cli.add_argument('--save_dir', type=str, required=True, help='Directory to save the JSON files')
    cli.add_argument('--cache_dir', type=str, required=True, help='Cache directory for datasets')
    # WAV roots: LibriTTS-R lives in its own tree, all other corpora share one.
    cli.add_argument('--libriR_wav_dir', type=str, required=True, help='Directories containing WAV files')
    cli.add_argument('--other_wav_dir', type=str, required=True, help='Directories containing WAV files')
    # Duration filter bounds (seconds).
    cli.add_argument('--audio_min_length', type=float, default=3.0, help='Minimum audio duration in seconds')
    cli.add_argument('--audio_max_length', type=float, default=18.0, help='Maximum audio duration in seconds')
    cli.add_argument('--splits', type=str, nargs='+', default=['train', 'val'],
                     help='List of splits to process')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode with limited data processing')
    return cli.parse_args()
def setup_logging() -> None:
    """Configure root logging: INFO level, timestamped messages to a stream handler."""
    stream_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[stream_handler],
    )
def process_dataset_split(split, dataset_split, args) -> List[Dict]:
    """
    Process a single dataset split and extract relevant records.

    Samples are skipped when the audio path is missing, the source corpus is
    unknown, the WAV file does not exist on disk, or the duration falls
    outside [args.audio_min_length, args.audio_max_length].

    Args:
        split: The name of the split (e.g., 'train').
        dataset_split: The dataset split object (indexable, supports len()).
        args: Parsed command-line arguments.

    Returns:
        A list of dictionaries containing the processed records.
    """
    logging.info(f"Processing split: {split}")
    filelist: List[Dict] = []
    total_duration: float = 0.0
    # Debug mode caps the sample count, but never beyond the split size
    # (the old fixed 500 could index past the end of a smaller split).
    num_samples: int = len(dataset_split) if not args.debug else min(len(dataset_split), 500)
    # Map each corpus name to the directory holding its WAV files.
    source_path = {
        'libritts-r': args.libriR_wav_dir,
        'voxceleb': args.other_wav_dir,
        'expresso': args.other_wav_dir,
        'ears': args.other_wav_dir,
        'vctk': args.other_wav_dir,
    }
    for idx in tqdm(range(num_samples), desc=f"Processing {split}"):
        try:
            data = dataset_split[idx]
        except IndexError:
            logging.warning(f"Index {idx} out of range for split '{split}'. Skipping.")
            continue
        audio_path: str = data.get("audio_path", "")
        duration: float = data.get("speech_duration", 0.0)
        source: str = data.get("source", "")
        # Validate the relative path BEFORE joining: joining an empty path
        # onto a directory yields the directory itself, which made this
        # guard unreachable in the original ordering.
        if not audio_path:
            logging.warning(f"Missing audio_path at index {idx} in split '{split}'. Skipping.")
            continue
        # An unknown source previously raised KeyError and aborted the run;
        # skip the sample with a warning instead.
        wav_dir = source_path.get(source)
        if wav_dir is None:
            logging.warning(f"Unknown source '{source}' at index {idx} in split '{split}'. Skipping.")
            continue
        audio_path = os.path.join(wav_dir, audio_path)
        if not os.path.exists(audio_path):
            logging.warning(f"WAV file does not exist: {audio_path}")
            continue
        if not (args.audio_min_length <= duration <= args.audio_max_length):
            continue
        record: Dict = {
            # Basename up to the first dot — same value as the original
            # audio_path.split('/')[-1].split('.')[0], but OS-portable.
            "segment_id": os.path.basename(audio_path).split('.')[0],
            "audio_path": audio_path,
            "text": data.get('text', ''),
            "caption": data.get('caption', ''),
            "duration": duration,
            "source": source,
        }
        filelist.append(record)
        total_duration += duration
    logging.info(f"Total duration for split '{split}': {total_duration / 3600:.2f} hrs.")
    logging.info(f"Total records for split '{split}': {len(filelist)}")
    return filelist
def save_json(filelist: List[Dict], output_path: Path) -> None:
    """
    Save the list of records to a JSON file.

    Write failures are logged rather than raised, so one bad split does not
    abort processing of the remaining splits.

    Args:
        filelist: List of dictionaries containing the records.
        output_path: Path to the output JSON file.
    """
    try:
        with output_path.open('w', encoding='utf-8') as fh:
            json.dump(filelist, fh, ensure_ascii=False, indent=4)
        logging.info(f"Saved {len(filelist)} records to '{output_path}'")
    except Exception as e:
        logging.error(f"Failed to save JSON to '{output_path}': {e}")
def main() -> None:
    """Entry point: load the dataset, filter each requested split, write JSON filelists."""
    args = parse_args()
    setup_logging()
    jsons_dir: Path = Path(args.save_dir) / 'jsons'
    jsons_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"JSON files will be saved to '{jsons_dir}'")
    logging.info("Loading dataset...")
    try:
        ds = load_dataset(args.hub)
        # NOTE(review): --cache_dir is accepted but currently unused; the
        # original kept this variant commented out:
        # ds = load_dataset(args.hub, cache_dir=args.cache_dir)
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        return
    requested = args.splits
    available_splits = set(ds.keys())
    missing_splits = set(requested) - available_splits
    if missing_splits:
        logging.warning(f"The following splits were not found in the dataset and will be skipped: {missing_splits}")
    # Process only the splits that actually exist in the loaded dataset.
    for split in (s for s in requested if s in available_splits):
        records = process_dataset_split(split, ds[split], args)
        save_json(records, jsons_dir / f"{split}.json")


if __name__ == "__main__":
    main()