File size: 4,525 Bytes
6b408d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""

Audio preprocessing for Wav2Vec2 model.



Handles conversion from audio arrays to model input tensors.

"""

import numpy as np
import torch
from transformers import Wav2Vec2Processor

from app.utils.constants import TARGET_SAMPLE_RATE
from app.utils.logger import get_logger

logger = get_logger(__name__)


class AudioPreprocessor:
    """
    Preprocessor for preparing audio data for Wav2Vec2 model.

    Converts numpy audio arrays into the tensor format expected
    by the Wav2Vec2ForSequenceClassification model.
    """

    def __init__(
        self,
        processor: Wav2Vec2Processor,
        device: str = "cpu",
    ) -> None:
        """
        Initialize AudioPreprocessor.

        Args:
            processor: Wav2Vec2Processor instance
            device: Target device for tensors (cpu/cuda)
        """
        self.processor = processor
        self.device = device
        # Expected input rate; callers must resample to this before preprocessing.
        self.sample_rate = TARGET_SAMPLE_RATE

    def validate_input(self, audio_array: np.ndarray) -> bool:
        """
        Validate audio array for processing.

        Args:
            audio_array: Input audio array

        Returns:
            True if valid

        Raises:
            ValueError: If validation fails (wrong type, wrong rank, empty,
                or containing NaN/infinite values)
        """
        if not isinstance(audio_array, np.ndarray):
            raise ValueError(f"Expected numpy array, got {type(audio_array)}")

        if audio_array.ndim != 1:
            raise ValueError(f"Expected 1D array, got {audio_array.ndim}D")

        if len(audio_array) == 0:
            raise ValueError("Audio array is empty")

        if np.isnan(audio_array).any():
            raise ValueError("Audio array contains NaN values")

        if np.isinf(audio_array).any():
            raise ValueError("Audio array contains infinite values")

        return True

    def _encode(
        self,
        audio: "np.ndarray | list[np.ndarray]",
        return_attention_mask: bool,
    ) -> dict[str, torch.Tensor]:
        """
        Run audio through the Wav2Vec2Processor and move tensors to the device.

        Shared by preprocess() and preprocess_batch(); the processor accepts
        either a single array or a list of arrays.
        """
        inputs = self.processor(
            audio,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=True,
            return_attention_mask=return_attention_mask,
        )
        # Move every produced tensor (input_values, optional attention_mask)
        # to the configured device in one pass.
        return {key: value.to(self.device) for key, value in inputs.items()}

    def preprocess(
        self,
        audio_array: np.ndarray,
        return_attention_mask: bool = True,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess audio array for model inference.

        Args:
            audio_array: 1D numpy array of audio samples (16kHz, normalized)
            return_attention_mask: Whether to return attention mask

        Returns:
            Dictionary with input_values and optionally attention_mask

        Raises:
            ValueError: If the input array fails validation
        """
        self.validate_input(audio_array)

        # Ensure float32; copy=False avoids a copy when the dtype already matches.
        audio_array = audio_array.astype(np.float32, copy=False)

        inputs = self._encode(audio_array, return_attention_mask)

        logger.debug(
            "Audio preprocessed for model",
            input_length=inputs["input_values"].shape[-1],
            device=self.device,
        )

        return inputs

    def preprocess_batch(
        self,
        audio_arrays: list[np.ndarray],
        return_attention_mask: bool = True,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess a batch of audio arrays.

        Args:
            audio_arrays: List of 1D numpy arrays
            return_attention_mask: Whether to return attention mask

        Returns:
            Dictionary with batched input_values and optionally attention_mask

        Raises:
            ValueError: If any array fails validation; the message includes
                the offending index
        """
        # Validate all inputs up front so the error names the bad element.
        for i, audio in enumerate(audio_arrays):
            try:
                self.validate_input(audio)
            except ValueError as e:
                raise ValueError(f"Invalid audio at index {i}: {e}") from e

        # Ensure float32; copy=False avoids copies for already-float32 arrays.
        audio_arrays = [audio.astype(np.float32, copy=False) for audio in audio_arrays]

        inputs = self._encode(audio_arrays, return_attention_mask)

        logger.debug(
            "Batch preprocessed for model",
            batch_size=len(audio_arrays),
            device=self.device,
        )

        return inputs