"""
Phonemes probabilities extraction
Audio wavs -> CUPE model -> phoneme probabilities 
"""

import torch
import torchaudio
import os
from tqdm import tqdm



from model2i import CUPEEmbeddingsExtractor  # main CUPE model's feature extractor
import windowing  # provides slice_windows, stich_window_predictions, calc_spec_len_ext

class EmbeddingsExtractionPipeline:
    """
    Pipeline for extracting allophone probabilities from audio using CUPE model
    """
    def __init__(self, cupe_ckpt_path, max_duration=10, verbose=True, device="cpu"):
        """
        Initialize the pipeline
        
        Args:
            cupe_ckpt_path: Path to CUPE model checkpoint
            device: Device to run inference on
        """
        self.device = device
        self.verbose = verbose
        self.extractor = CUPEEmbeddingsExtractor(cupe_ckpt_path, device=self.device)
        self.config(max_duration=max_duration)
        if self.verbose:
            print("max_frames_per_clip:", self.max_frames_per_clip.item())

        # Run a silent dummy clip through the pipeline to sanity-check output shapes
        dummy_wav = torch.zeros(1, self.max_wav_len, dtype=torch.float32, device="cpu")
        dummy_wav = dummy_wav.unsqueeze(0)  # add batch dimension -> (1, 1, max_wav_len)
        dummy_logits, dummy_spectral_lens = self._process_audio_batch(
            audio_batch=dummy_wav,
            wav_lens=torch.tensor([dummy_wav.shape[2]], dtype=torch.long),
        )

        if self.verbose:
            print(f"Dummy logits shape: {dummy_logits.shape}, Dummy spectral lengths: {dummy_spectral_lens}")

        assert dummy_logits.shape[1] == self.max_frames_per_clip, f"Dummy logits shape mismatch: {dummy_logits.shape[1]} vs {self.max_frames_per_clip}"
        assert dummy_logits.shape[2] == self.output_dim, f"Dummy logits output dimension mismatch: {dummy_logits.shape[2]} vs {self.output_dim}"

        # Resampler for audio preprocessing - recommended for all audio inputs,
        # even if they are already 16 kHz, for consistency
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=self.sample_rate,
            new_freq=self.sample_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

    def config(self, max_duration=10, window_size_ms=120, stride_ms=80):
        """
        Configure pipeline parameters

        Args:
            max_duration: Maximum duration of audio in seconds
            window_size_ms: Window size in milliseconds
            stride_ms: Stride size in milliseconds
        """
        self.sample_rate = 16000
        self.window_size_ms = window_size_ms
        self.stride_ms = stride_ms
        self.phoneme_classes = 66
        self.phoneme_groups = 16
        self.extract_phoneme_groups = True  # whether to extract phoneme groups (16 phoneme groups)
        self.extract_phoneme_individuals = True  # whether to extract phoneme individuals (66 phoneme classes)
        # if both of the above are True, then the output dimension is 66 + 16 = 82 (logits from both tasks concatenated)
        self.output_dim = 0
        if self.extract_phoneme_individuals: self.output_dim += self.phoneme_classes
        if self.extract_phoneme_groups: self.output_dim += self.phoneme_groups
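        # Downstream code can split the concatenated output back into the two
        # tasks; a minimal sketch, assuming both flags above are left enabled:
        #   class_logits = logits[..., :self.phoneme_classes]  # first 66 dims
        #   group_logits = logits[..., self.phoneme_classes:]  # last 16 dims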
        self.window_size_wav = int(window_size_ms * self.sample_rate / 1000)
        self.stride_size_wav = int(stride_ms * self.sample_rate / 1000)
        self.max_wav_len = max_duration * self.sample_rate  # max duration in seconds to samples
        # Get frames_per_window from model
        self.frames_per_window = self.extractor.model.update_frames_per_window(self.window_size_wav)[0]
        self.max_frames_per_clip = windowing.calc_spec_len_ext(
            torch.tensor([self.max_wav_len], dtype=torch.long),
            self.window_size_ms, self.stride_ms, self.sample_rate,
            frames_per_window=self.frames_per_window,
            disable_windowing=False,
            wav_len_max=self.max_wav_len,
        )[0]
        # Estimate frames-per-second from a very long hypothetical wav length
        self.frames_per_second = windowing.calc_spec_len_ext(
            torch.tensor([100000 * self.sample_rate], dtype=torch.long),
            self.window_size_ms, self.stride_ms, self.sample_rate,
            frames_per_window=self.frames_per_window,
            disable_windowing=False,
            wav_len_max=self.max_wav_len,
        )[0].float() / 100000
        
        self.ms_per_frame = int(1000 / self.frames_per_second.item())
        if self.verbose:
            print(f"frames_per_window: {self.frames_per_window.item()}, max frames per clip: {self.max_frames_per_clip.item()}")  # INFO: 10 frames per 120ms window, 620 frames for a 10s clip
            print(f"Frames per second: {self.frames_per_second.item()}")  # INFO: 62.49995040893555 frames per second
            print(f"Milliseconds per frame: {self.ms_per_frame}")  # INFO: 16 milliseconds per frame
    def _process_audio_batch(self, audio_batch, wav_lens):
        """
        Process a batch of audio to extract logits
        
        Args:
            audio_batch: Batch of audio waveforms
            wav_lens: Lengths of each audio in the batch
            
        Returns:
            logits: Stitched per-frame logits of shape (batch_size, max_frames_per_clip, output_dim)
            spectral_lens: Number of valid (non-padding) frames per clip
        """
        # Window the audio
        windowed_audio = windowing.slice_windows(
            audio_batch.to(self.device), 
            self.sample_rate, 
            self.window_size_ms, 
            self.stride_ms
        )
        batch_size, num_windows, window_size = windowed_audio.shape
        windows_flat = windowed_audio.reshape(-1, window_size)
        
        # Get per-window predictions from both heads
        _logits_class, _logits_group = self.extractor.predict(windows_flat, return_embeddings=False, groups_only=False)

        # Concatenate class logits and group logits along the feature dimension
        logits = torch.cat([_logits_class[:, :, :self.phoneme_classes], _logits_group[:, :, :self.phoneme_groups]], dim=2)
        assert logits.shape[2] == self.output_dim

        frames_per_window = logits.shape[1]  # INFO: 10 frames per window
        assert frames_per_window == self.frames_per_window, f"Expected {self.frames_per_window} frames per window, got {frames_per_window}"
        
        # Reshape back to (batch, windows, frames, dims) and stitch the
        # overlapping window predictions into one per-frame sequence
        logits = logits.reshape(batch_size, num_windows, frames_per_window, -1)
        logits = windowing.stich_window_predictions(
            logits,
            original_audio_length=audio_batch.size(2),
            cnn_output_size=frames_per_window,
            sample_rate=self.sample_rate,
            window_size_ms=self.window_size_ms,
            stride_ms=self.stride_ms
        )
        # logits: (batch_size, max_frames_per_clip, output_dim)
        assert logits.shape[1] == self.max_frames_per_clip, f"Phoneme logits shape mismatch: {logits.shape[1]} vs {self.max_frames_per_clip}"

        # Calculate spectral lengths
        spectral_lens = windowing.calc_spec_len_ext(
            wav_lens, 
            self.window_size_ms, 
            self.stride_ms,
            self.sample_rate, 
            frames_per_window=self.frames_per_window,
            disable_windowing=False,
            wav_len_max=self.max_wav_len
        )
        assert max(spectral_lens) <= self.max_frames_per_clip, f"Max spectral length {max(spectral_lens)} exceeds {self.max_frames_per_clip}"
        assert min(spectral_lens) > 0, f"Min spectral length {min(spectral_lens)} is not valid"

        return logits, spectral_lens
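    # Note: _process_audio_batch returns raw logits. To obtain the phoneme
    # probabilities named in the module docstring, apply softmax per task;
    # a sketch, assuming the default config with both tasks enabled:
    #   class_probs = torch.softmax(logits[..., :self.phoneme_classes], dim=-1)
    #   group_probs = torch.softmax(logits[..., self.phoneme_classes:], dim=-1)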
    
    def extract_embeddings_dataloader(self, dataloader):  # example code, dataloader not implemented
        """
        Extract phoneme features for every clip yielded by a dataloader

        Args:
            dataloader: Iterable yielding (audio_batch, wav_lens, clip_id) batches

        Returns:
            features_collation: List of (clip_id, features) tuples
        """
        print("Starting phoneme embeddings extraction process...")
        features_collation = []

        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Extracting phonemes")):
            audio_batch, wav_lens, clip_id = batch_data

            # Process audio; returns shape (batch_size, max_frames_per_clip, output_dim)
            class_probs, spectral_lens = self._process_audio_batch(audio_batch, wav_lens)

            # Process each sequence in the batch individually
            batch_size = spectral_lens.shape[0]
            # TODO: optionally concat class_probs with frames_confidence and formants

            for i in range(batch_size):
                # De-pad the sequence to its valid spectral length
                if spectral_lens is not None:
                    class_probs_i = class_probs[i][:spectral_lens[i]]
                else:
                    class_probs_i = class_probs[i]

                features_i = class_probs_i.detach()
                cid = clip_id[i].item()
                features_collation.append((cid, features_i))
                self.output_handler(cid, features_i)

        print(f"Extracted {len(features_collation)} allophone embeddings")
        if len(features_collation) == 0:
            raise ValueError("No valid phoneme features were extracted.")

        return features_collation
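    # A minimal sketch of a dataset compatible with extract_embeddings_dataloader
    # (hypothetical; not part of this repo). Each batch must yield
    # (audio_batch, wav_lens, clip_id): audio_batch as (B, 1, max_wav_len)
    # float32 padded to max_wav_len, wav_lens as (B,) long with the unpadded
    # sample counts, and clip_id as a (B,) long tensor of integer ids.
    #
    #   class PaddedClipDataset(torch.utils.data.Dataset):
    #       def __init__(self, paths, max_wav_len):
    #           self.paths, self.max_wav_len = paths, max_wav_len
    #       def __len__(self):
    #           return len(self.paths)
    #       def __getitem__(self, idx):
    #           wav, sr = torchaudio.load(self.paths[idx])
    #           wav = wav.mean(dim=0, keepdim=True)[:, :self.max_wav_len]
    #           wav_len = wav.shape[1]
    #           wav = torch.nn.functional.pad(wav, (0, self.max_wav_len - wav_len))
    #           return wav, torch.tensor(wav_len), torch.tensor(idx)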
    

    def output_handler(self, clip_id, features):  # callback to handle the pipeline's output - not implemented
        """
        Handle the output of the pipeline

        Args:
            clip_id: Integer identifier of the processed clip
            features: De-padded per-frame features for the clip
        """
        output_length = features.shape[0]  # already de-padded
        if self.verbose:
            print(f"Output handler received {features.shape[1]}-dim features of length {output_length} for clip {clip_id}")
        
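    # A possible concrete handler; a sketch, assuming one .pt file per clip is
    # wanted (the "phoneme_features" directory name is hypothetical):
    #   def output_handler(self, clip_id, features):
    #       os.makedirs("phoneme_features", exist_ok=True)
    #       torch.save(features.cpu(), f"phoneme_features/{clip_id}.pt")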


def process_single_clip(path_to_audio, pipeline, unpad_output=True):
    """
    Process a single audio clip to extract phoneme features

    Args:
        path_to_audio: Path to the audio file
        pipeline: EmbeddingsExtractionPipeline instance
        unpad_output: Whether to drop padding frames from the output

    Returns:
        Per-frame features of shape (num_frames, pipeline.output_dim)
    """
    audio_clip, sr = torchaudio.load(path_to_audio)
    
    if sr != pipeline.sample_rate:
        raise ValueError(f"Sample rate mismatch: {sr} vs {pipeline.sample_rate}")

    if audio_clip.shape[0] > 1:
        audio_clip = audio_clip.mean(dim=0, keepdim=True)  # Convert to mono if stereo

    audio_clip = audio_clip.to(pipeline.device)
    audio_clip = pipeline.resampler(audio_clip)  # consistency resampling (16 kHz -> 16 kHz)
    audio_clip = audio_clip.unsqueeze(0)  # Add batch dimension -> (1, 1, num_samples)

    if audio_clip.shape[2] > pipeline.max_wav_len:
        print(f"Audio clip {path_to_audio} exceeds max length {pipeline.max_wav_len}, trimming to max length.")
        audio_clip = audio_clip[:, :, :pipeline.max_wav_len]  # Trim along the time (last) dimension

    original_length = audio_clip.shape[2]
    if audio_clip.shape[2] < pipeline.max_wav_len:
        audio_clip = torch.nn.functional.pad(audio_clip, (0, pipeline.max_wav_len - audio_clip.shape[2]))

    features, spectral_lens = pipeline._process_audio_batch(audio_batch=audio_clip, wav_lens=torch.tensor([original_length], dtype=torch.long))

    features = features.squeeze(0)  # Remove batch dimension

    if unpad_output:
        features = features[:spectral_lens[0], :]  # drop padding frames

    print(f"Output shape: {features.shape} for audio clip {path_to_audio}")
    return features


if __name__ == "__main__":

    torch.manual_seed(42)

    cupe_ckpt_path = "ckpt/en_libri1000_uj01d_e199_val_GER=0.2307.ckpt"
    pipeline = EmbeddingsExtractionPipeline(cupe_ckpt_path, max_duration=10, device="cpu", verbose=False)
    
    audio_clip1_path = "samples/109867__timkahn__butterfly.wav.wav"
    audio_clip2_path = "samples/Schwa-What.mp3.wav"

    features1 = process_single_clip(audio_clip1_path, pipeline)
    features2 = process_single_clip(audio_clip2_path, pipeline)

    print("Done!")