"""
DataLoader implementation for ASR training.
"""
import torch
from torch.utils.data import DataLoader
from transformers import WhisperProcessor
class ASRDataLoader:
    """DataLoader factory for Whisper ASR fine-tuning.

    Converts dataset items of the form
    ``{"audio": {"array": <1-D audio samples>}, "transcription": <str>}``
    into fixed-shape ``input_features`` / ``labels`` tensor batches via a
    ``WhisperProcessor``, and builds ``torch.utils.data.DataLoader``
    instances that use a crash-tolerant collate function.

    Parameters
    ----------
    processor:
        A HuggingFace ``WhisperProcessor`` exposing ``feature_extractor``
        and ``tokenizer``.
    batch_size, num_workers, pin_memory, persistent_workers, prefetch_factor:
        Forwarded to ``torch.utils.data.DataLoader``. ``persistent_workers``
        and ``prefetch_factor`` are only forwarded when ``num_workers > 0``,
        because PyTorch rejects them for in-process loading.
    max_frames:
        Log-mel frame length every feature tensor is trimmed/padded to.
    sample_rate:
        Sampling rate passed to the feature extractor.
    num_mel_bins:
        Mel-channel count used for the dummy fallback batch (default 80,
        matching the previous hard-coded value).
    max_label_length:
        Token length labels are padded/truncated to (default 256, matching
        the previous hard-coded value).
    """

    def __init__(
        self,
        processor: "WhisperProcessor",
        batch_size: int = 32,
        num_workers: int = 2,
        max_frames: int = 3000,
        sample_rate: int = 16000,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        num_mel_bins: int = 80,
        max_label_length: int = 256,
    ):
        self.processor = processor
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_frames = max_frames
        self.sample_rate = sample_rate
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor
        self.num_mel_bins = num_mel_bins
        self.max_label_length = max_label_length

    def _dummy_batch(self):
        """Return a harmless all-padding batch used when collation fails.

        Labels of -100 are ignored by HF's cross-entropy loss, so this
        batch contributes essentially nothing to training while keeping
        the training loop alive.
        """
        return {
            "input_features": torch.zeros(1, self.num_mel_bins, self.max_frames),
            "labels": torch.full(
                (1, self.max_label_length), -100, dtype=torch.long
            ),
        }

    def collate_fn(self, batch):
        """Collate raw dataset items into a model-ready batch.

        Items that are falsy or whose transcription is empty/whitespace
        are dropped. Returns ``None`` when nothing survives filtering;
        :meth:`safe_collate_fn` substitutes a dummy batch in that case.
        """
        batch = [b for b in batch if b and b["transcription"].strip()]
        if not batch:
            return None
        features = []
        label_rows = []
        for item in batch:
            audio_inputs = self.processor.feature_extractor(
                item["audio"]["array"],
                sampling_rate=self.sample_rate,
                return_tensors="pt",
            )
            text_inputs = self.processor.tokenizer(
                item["transcription"],
                padding="max_length",
                truncation=True,
                max_length=self.max_label_length,
                return_tensors="pt",
            )
            feats = audio_inputs.input_features[0]
            # Force every spectrogram to exactly max_frames so stacking works.
            width = feats.shape[-1]
            if width > self.max_frames:
                feats = feats[..., : self.max_frames]
            elif width < self.max_frames:
                feats = torch.nn.functional.pad(
                    feats, (0, self.max_frames - width)
                )
            features.append(feats)
            labels = text_inputs.input_ids[0]
            # Mask padding so the loss ignores it. NOTE(review): Whisper's
            # tokenizer typically uses the same id for pad and EOS, so this
            # may also mask the final EOS token — confirm this is intended.
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
            label_rows.append(labels)
        return {
            "input_features": torch.stack(features, dim=0),
            "labels": torch.stack(label_rows, dim=0),
        }

    def safe_collate_fn(self, batch):
        """Wrapper around :meth:`collate_fn` that never raises.

        Any error (malformed item, tokenizer failure, ...) or an
        all-filtered batch yields a dummy all-padding batch instead of
        killing the DataLoader worker process.
        """
        try:
            result = self.collate_fn(batch)
        except Exception as e:  # deliberate best-effort: worker must survive
            print(f"Collate error: {e}")
            return self._dummy_batch()
        return result if result is not None else self._dummy_batch()

    def get_loader(self, dataset, shuffle=True, drop_last=True):
        """Build a DataLoader over *dataset* using the safe collate function.

        ``persistent_workers`` and ``prefetch_factor`` are only forwarded
        when worker processes are actually used: PyTorch raises
        ``ValueError`` if either is supplied with ``num_workers == 0``.
        """
        worker_kwargs = {}
        if self.num_workers > 0:
            worker_kwargs["persistent_workers"] = self.persistent_workers
            worker_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.safe_collate_fn,
            drop_last=drop_last,
            **worker_kwargs,
        )