"""
Dataset implementation for LibriSpeech ASR.
"""
import os
import glob
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
import torch
from datasets import Dataset, Audio
from sklearn.model_selection import train_test_split
def load_dataset(
    root_dir: str,
    processor: Any,
    sample_cap: Optional[int] = None,
    val_ratio: float = 0.2
) -> Tuple[Dataset, Dataset]:
    """Load and prepare the train/validation datasets.

    Args:
        root_dir: Root directory of the LibriSpeech-style corpus.
        processor: Whisper processor (accepted for API compatibility; not used here).
        sample_cap: Optional upper bound on the number of samples loaded.
        val_ratio: Fraction of samples reserved for validation.

    Returns:
        A ``(train_dataset, validation_dataset)`` tuple.
    """
    builder = LibriSpeechDataset(
        root_dir=root_dir,
        sample_cap=sample_cap,
        val_ratio=val_ratio
    )
    # prepare_datasets() also yields a test split; this wrapper discards it.
    train_ds, val_ds, _test_ds = builder.prepare_datasets()
    return train_ds, val_ds
class LibriSpeechDataset:
    """Builds HuggingFace ``Dataset`` splits from a LibriSpeech-style directory tree.

    Expects ``root_dir`` to contain split folders (``train-clean-100``,
    ``dev-clean``, ``test-clean``) holding ``*.flac`` files alongside
    ``*.trans.txt`` transcript files (one ``<utterance_id> <text>`` line each).
    """

    def __init__(
        self,
        root_dir: str,
        sample_rate: int = 16000,
        sample_cap: Optional[int] = None,
        val_ratio: float = 0.2
    ):
        """
        Args:
            root_dir: Corpus root directory containing the split folders.
            sample_rate: Target audio sampling rate in Hz.
            sample_cap: Optional cap on the number of train+val samples.
            val_ratio: Fraction of train+val samples used for validation.
        """
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.sample_cap = sample_cap
        self.val_ratio = val_ratio
        # Populated by prepare_datasets().
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None

    @staticmethod
    def _load_transcripts(directory: str) -> Dict[str, str]:
        """Parse every ``*.trans.txt`` in *directory* into ``{utterance_id: text}``.

        Keying on the exact first token (instead of a ``startswith`` prefix
        test) avoids mis-matching IDs that are prefixes of longer IDs.
        """
        transcripts: Dict[str, str] = {}
        for txt in glob.glob(os.path.join(directory, "*.trans.txt")):
            with open(txt, "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.strip().split(" ", 1)
                    if len(parts) == 2:
                        transcripts[parts[0]] = parts[1]
        return transcripts

    def load_split(self, splits: list, cap: Optional[int] = None) -> "Dataset":
        """Collect (audio path, transcript) pairs from one or more split folders.

        Args:
            splits: Split directory names relative to ``root_dir``.
            cap: Optional maximum number of samples (falsy values mean no cap,
                matching the original behavior).

        Returns:
            A ``Dataset`` with ``audio`` (cast to ``sample_rate``) and
            ``transcription`` columns.

        Raises:
            ValueError: If no audio/transcript pairs were found.
        """
        audio_paths: List[str] = []
        transcripts: List[str] = []
        # Transcript files are read once per directory instead of once per
        # .flac file (the original re-scanned every trans.txt per audio file).
        trans_cache: Dict[str, Dict[str, str]] = {}
        for split in splits:
            split_dir = os.path.join(self.root_dir, split)
            if not os.path.isdir(split_dir):
                print(f"Warning: missing {split_dir}, skipping")
                continue
            for flac_path in glob.glob(f"{split_dir}/**/*.flac", recursive=True):
                if cap and len(audio_paths) >= cap:
                    break
                flac_dir = os.path.dirname(flac_path)
                if flac_dir not in trans_cache:
                    trans_cache[flac_dir] = self._load_transcripts(flac_dir)
                text = trans_cache[flac_dir].get(Path(flac_path).stem)
                # Files without a matching transcript line are skipped silently,
                # as in the original implementation.
                if text is not None:
                    audio_paths.append(flac_path)
                    transcripts.append(text)
            if cap and len(audio_paths) >= cap:
                break
        if not audio_paths:
            raise ValueError(f"No audio files found under {self.root_dir} for {splits}")
        print(f"Found {len(audio_paths)} audio files")
        ds = Dataset.from_dict({"audio": audio_paths, "transcription": transcripts})
        return ds.cast_column("audio", Audio(sampling_rate=self.sample_rate))

    def prepare_datasets(self) -> Tuple["Dataset", "Dataset", "Dataset"]:
        """Prepare the train, validation, and test datasets.

        Loads train-clean-100 + dev-clean (capped at ``sample_cap``) and splits
        them into train/val by ``val_ratio``; test-clean is loaded uncapped.

        Returns:
            ``(train_ds, val_ds, test_ds)``.
        """
        train_val_ds = self.load_split(["train-clean-100", "dev-clean"], cap=self.sample_cap)
        test_ds = self.load_split(["test-clean"], cap=None)
        # Deterministic shuffled split (fixed seed) into train/validation.
        idx = list(range(len(train_val_ds)))
        train_idx, val_idx = train_test_split(
            idx,
            test_size=self.val_ratio,
            random_state=42,
            shuffle=True
        )
        self.train_ds = train_val_ds.select(train_idx)
        self.val_ds = train_val_ds.select(val_idx)
        self.test_ds = test_ds
        print(f"Samples → train: {len(self.train_ds)}, "
              f"val: {len(self.val_ds)}, test: {len(self.test_ds)}")
        return self.train_ds, self.val_ds, self.test_ds

    @staticmethod
    def summarize_real_durations(ds, name, sr=16_000):
        """Print sample count, average duration, and total hours for *ds*.

        Decodes audio at *sr* and measures actual array lengths (batched,
        multi-process; cache disabled to force fresh decoding).

        Raises:
            RuntimeError: If *ds* is not a HuggingFace ``Dataset``.
        """
        if not isinstance(ds, Dataset):
            raise RuntimeError(f"{name}_ds 不是 HuggingFace Dataset,请重新运行 Cell 1")
        ds2 = ds.cast_column("audio", Audio(sampling_rate=sr))
        ds2 = ds2.map(
            lambda batch: {
                "duration": [
                    len(item["array"]) / sr
                    for item in batch["audio"]
                ]
            },
            batched=True,
            batch_size=32,
            num_proc=4,  # NOTE(review): may need num_proc=None on small datasets/Windows — confirm
            remove_columns=["audio"],
            load_from_cache_file=False,
        )
        durations = ds2["duration"]
        total_h = sum(durations) / 3600.0
        avg_s = sum(durations) / len(durations)
        print(f"{name:5s} | Samples: {len(ds2):4d} | Avg: {avg_s:5.1f}s | Total: {total_h:5.1f}h")