| """ |
| DataLoader implementation for ASR training. |
| """ |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| from transformers import WhisperProcessor |
|
|
| class ASRDataLoader: |
| def __init__( |
| self, |
| processor: WhisperProcessor, |
| batch_size: int = 32, |
| num_workers: int = 2, |
| max_frames: int = 3000, |
| sample_rate: int = 16000, |
| pin_memory: bool = True, |
| persistent_workers: bool = True, |
| prefetch_factor: int = 2 |
| ): |
|
|
| self.processor = processor |
| self.batch_size = batch_size |
| self.num_workers = num_workers |
| self.max_frames = max_frames |
| self.sample_rate = sample_rate |
| self.pin_memory = pin_memory |
| self.persistent_workers = persistent_workers |
| self.prefetch_factor = prefetch_factor |
|
|
| def collate_fn(self, batch): |
| batch = [b for b in batch if b and b["transcription"].strip()] |
| if not batch: |
| return None |
|
|
| processed_features = [] |
| processed_labels = [] |
|
|
| for item in batch: |
| audio_array = item["audio"]["array"] |
| text = item["transcription"] |
|
|
| audio_inputs = self.processor.feature_extractor( |
| audio_array, |
| sampling_rate=self.sample_rate, |
| return_tensors="pt" |
| ) |
|
|
| text_inputs = self.processor.tokenizer( |
| text, |
| padding="max_length", |
| truncation=True, |
| max_length=256, |
| return_tensors="pt" |
| ) |
|
|
| feats = audio_inputs.input_features[0] |
| if feats.shape[-1] > self.max_frames: |
| feats = feats[..., :self.max_frames] |
| elif feats.shape[-1] < self.max_frames: |
| pad_size = self.max_frames - feats.shape[-1] |
| feats = torch.nn.functional.pad(feats, (0, pad_size)) |
|
|
| processed_features.append(feats) |
|
|
| labels = text_inputs.input_ids[0] |
| labels[labels == self.processor.tokenizer.pad_token_id] = -100 |
| processed_labels.append(labels) |
|
|
| batch_features = torch.stack(processed_features, dim=0) |
| batch_labels = torch.stack(processed_labels, dim=0) |
|
|
| return { |
| "input_features": batch_features, |
| "labels": batch_labels |
| } |
| |
| def safe_collate_fn(self, batch): |
| try: |
| result = self.collate_fn(batch) |
| if result is None: |
| return { |
| "input_features": torch.zeros(1, 80, self.max_frames), |
| "labels": torch.full((1, 256), -100, dtype=torch.long) |
| } |
| return result |
| except Exception as e: |
| print(f"Collate error: {e}") |
| return { |
| "input_features": torch.zeros(1, 80, self.max_frames), |
| "labels": torch.full((1, 256), -100, dtype=torch.long) |
| } |
| |
| def get_loader(self, dataset, shuffle=True, drop_last=True): |
| return DataLoader( |
| dataset, |
| batch_size=self.batch_size, |
| shuffle=shuffle, |
| num_workers=self.num_workers, |
| pin_memory=self.pin_memory, |
| collate_fn=self.safe_collate_fn, |
| persistent_workers=self.persistent_workers, |
| prefetch_factor=self.prefetch_factor, |
| drop_last=drop_last |
| ) |