File size: 3,406 Bytes
5f2f308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
DataLoader implementation for ASR training.
""" 

import torch
from torch.utils.data import DataLoader
from transformers import WhisperProcessor

class ASRDataLoader:
    """Build padded ``(input_features, labels)`` batches for Whisper ASR training.

    Wraps a HuggingFace ``WhisperProcessor`` to turn raw dataset items of the
    form ``{"audio": {"array": ...}, "transcription": str}`` into fixed-size
    tensors, and constructs a ``torch.utils.data.DataLoader`` that uses a
    crash-proof collate function (bad batches are replaced by an all-padding
    dummy batch instead of killing a worker).
    """

    def __init__(
        self,
        processor: "WhisperProcessor",
        batch_size: int = 32,
        num_workers: int = 2,
        max_frames: int = 3000,
        sample_rate: int = 16000,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        max_label_length: int = 256,
    ):
        """
        Args:
            processor: WhisperProcessor supplying ``feature_extractor`` and
                ``tokenizer``.
            batch_size: Items per batch.
            num_workers: DataLoader worker processes (0 = load in main process).
            max_frames: Fixed time dimension features are padded/truncated to
                (3000 frames corresponds to Whisper's 30 s window).
            sample_rate: Audio sample rate in Hz passed to the feature extractor.
            pin_memory: Pin host memory for faster GPU transfer.
            persistent_workers: Keep workers alive between epochs
                (only honored when ``num_workers > 0``).
            prefetch_factor: Batches prefetched per worker
                (only honored when ``num_workers > 0``).
            max_label_length: Token length labels are padded/truncated to.
        """
        self.processor = processor
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_frames = max_frames
        self.sample_rate = sample_rate
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor
        self.max_label_length = max_label_length

    def collate_fn(self, batch):
        """Convert a list of dataset items into a model-ready batch dict.

        Items that are falsy or whose transcription is blank are dropped.
        Returns ``None`` when nothing survives filtering; ``safe_collate_fn``
        converts that into a dummy batch.
        """
        batch = [b for b in batch if b and b["transcription"].strip()]
        if not batch:
            return None

        processed_features = []
        processed_labels = []

        for item in batch:
            audio_array = item["audio"]["array"]
            text = item["transcription"]

            audio_inputs = self.processor.feature_extractor(
                audio_array,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
            )

            text_inputs = self.processor.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=self.max_label_length,
                return_tensors="pt",
            )

            # Force a fixed time dimension so features stack into one tensor.
            feats = audio_inputs.input_features[0]
            if feats.shape[-1] > self.max_frames:
                feats = feats[..., :self.max_frames]
            elif feats.shape[-1] < self.max_frames:
                pad_size = self.max_frames - feats.shape[-1]
                feats = torch.nn.functional.pad(feats, (0, pad_size))

            processed_features.append(feats)

            # Replace padding positions with -100 so the loss ignores them.
            # NOTE(review): Whisper's pad token usually equals its EOS token,
            # so this also masks the final EOS — confirm the training loop
            # expects that.
            labels = text_inputs.input_ids[0]
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
            processed_labels.append(labels)

        batch_features = torch.stack(processed_features, dim=0)
        batch_labels = torch.stack(processed_labels, dim=0)

        return {
            "input_features": batch_features,
            "labels": batch_labels
        }

    def _dummy_batch(self):
        """All-padding stand-in batch used when collation yields nothing or fails.

        Labels are all -100, so the loss on this batch is zero.
        NOTE(review): assumes 80 mel bins (Whisper tiny..large-v2); large-v3
        uses 128 — adjust if that checkpoint is used.
        """
        return {
            "input_features": torch.zeros(1, 80, self.max_frames),
            "labels": torch.full((1, self.max_label_length), -100, dtype=torch.long),
        }

    def safe_collate_fn(self, batch):
        """Collate wrapper that never raises: falls back to a dummy batch.

        Broad ``except`` is deliberate — a single malformed sample must not
        crash a DataLoader worker mid-epoch.
        """
        try:
            result = self.collate_fn(batch)
            if result is None:
                return self._dummy_batch()
            return result
        except Exception as e:
            print(f"Collate error: {e}")
            return self._dummy_batch()

    def get_loader(self, dataset, shuffle=True, drop_last=True):
        """Build a DataLoader over *dataset* using the safe collate function.

        ``persistent_workers`` and ``prefetch_factor`` are only forwarded when
        ``num_workers > 0``; DataLoader raises ``ValueError`` if they are set
        with zero workers.
        """
        worker_kwargs = {}
        if self.num_workers > 0:
            worker_kwargs["persistent_workers"] = self.persistent_workers
            worker_kwargs["prefetch_factor"] = self.prefetch_factor

        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.safe_collate_fn,
            drop_last=drop_last,
            **worker_kwargs,
        )