import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import sentencepiece as spm

import torchaudio
from torchaudio.transforms import Resample
|
|
# Load the trained SentencePiece tokenizer used to turn sentences into token ids.
sp = spm.SentencePieceProcessor()
sp.Load("./ressources/tokenizer/128_v7.model")
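# Illustrative only: a quick round trip through the tokenizer (the sentence below
# is a placeholder, not taken from the dataset).
#   ids = sp.Encode("this is a test sentence", out_type=int)
#   text = sp.Decode(ids)  # recovers the sentence, modulo tokenizer normalization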
|
|
# Load the train/dev/test splits; each CSV maps clip paths to reference sentences.
train_data = pd.read_csv("./ressources/train.csv", low_memory=False)
validation_data = pd.read_csv("./ressources/dev.csv", low_memory=False)
test_data = pd.read_csv("./ressources/test.csv", low_memory=False)
|
|
X_train, y_train = train_data["path"], train_data["sentence"]
X_val, y_val = validation_data["path"], validation_data["sentence"]
X_test, y_test = test_data["path"], test_data["sentence"]
|
|
# Only the path/sentence columns are needed from here on.
del train_data, validation_data, test_data
|
|
|
|
# Root directory holding the audio clips, supplied via the environment.
audio_location = os.environ.get("AUDIO_LOCATION")
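# Optional guard (an addition, assuming the variable must be set before any clip
# is loaded): failing early here is clearer than a FileNotFoundError raised later
# inside a DataLoader worker.
if audio_location is None:
    raise RuntimeError(
        "Set the AUDIO_LOCATION environment variable to the audio root directory."
    )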
|
|
|
|
def collate_fn(batch):
    """Pad variable-length waveforms and token sequences into a single batch.

    Returns (padded_waveforms, padded_transcriptions, audio_lengths,
    transcription_lengths); the two length tensors hold the unpadded sizes.
    """
    # Drop items the dataset returned as None (e.g. unreadable clips).
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None

    transcriptions, waveforms, audio_lengths = zip(*batch)

    transcriptions = [torch.tensor(t, dtype=torch.long) for t in transcriptions]
    # Waveforms are already tensors; cast instead of re-wrapping them with
    # torch.tensor, which would emit a copy-construct warning.
    waveforms = [w.to(torch.float32) for w in waveforms]

    # Unpadded lengths, so the model/loss can ignore the padded positions.
    transcription_lengths = torch.tensor(
        [t.size(0) for t in transcriptions], dtype=torch.int32
    )
    audio_lengths = torch.tensor(audio_lengths, dtype=torch.int32)

    padded_waveforms = pad_sequence(waveforms, batch_first=True, padding_value=0.0)
    padded_transcriptions = pad_sequence(
        transcriptions, batch_first=True, padding_value=0
    )

    return padded_waveforms, padded_transcriptions, audio_lengths, transcription_lengths
|
|
|
|
class AudioDataset(Dataset):
    """Maps a (path, sentence) pair to (token ids, 16 kHz mono waveform, n_samples)."""

    def __init__(self, X, y, audio_location=audio_location, train=False):
        self.audio_dirs = X.reset_index(drop=True)
        self.transcriptions = y.reset_index(drop=True)
        self.train = train
        self.audio_location = audio_location

        self.target_sr = 16000
        # The resampler is built lazily and cached per source sample rate.
        self.resampler = None

    def __len__(self):
        return len(self.transcriptions)

    def __getitem__(self, idx):
        # A row may list several alternative clips separated by commas.
        paths = str(self.audio_dirs[idx]).split(",")

        # During training, pick one of the alternatives at random; otherwise
        # always use the first so evaluation stays deterministic.
        if self.train:
            chosen = random.choice(paths)
        else:
            chosen = paths[0]

        audio_path = f"{self.audio_location}/{chosen}.mp3"

        # Tokenize the reference sentence into SentencePiece ids.
        transcription = sp.Encode(self.transcriptions[idx], out_type=int)

        # Skip unreadable clips; collate_fn already drops None items.
        try:
            waveform, sr = torchaudio.load(audio_path)
        except Exception:
            return None

        # Downmix multi-channel audio to mono.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the target rate, rebuilding the cached resampler if the
        # source rate changed.
        if sr != self.target_sr:
            if self.resampler is None or self.resampler.orig_freq != sr:
                self.resampler = Resample(orig_freq=sr, new_freq=self.target_sr)
            waveform = self.resampler(waveform)

        waveform = waveform.squeeze(0)

        return transcription, waveform, waveform.size(0)
|
|
|
|
# Wrap the splits in datasets; only training uses random clip selection.
train_data = AudioDataset(X_train, y_train, train=True)
validation_data = AudioDataset(X_val, y_val)
test_data = AudioDataset(X_test, y_test)
|
|
train_dataloader = DataLoader(
    train_data,
    shuffle=True,
    drop_last=True,
    batch_size=64,
    num_workers=8,
    collate_fn=collate_fn,
    pin_memory=True,
    persistent_workers=True,
)
|
|
validation_dataloader = DataLoader(
    validation_data,
    batch_size=64,
    num_workers=4,
    collate_fn=collate_fn,
    persistent_workers=True,
)
|
|
test_dataloader = DataLoader(
    test_data,
    batch_size=4,
    num_workers=4,
    collate_fn=collate_fn,
)
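
# Minimal smoke test (illustrative, assuming AUDIO_LOCATION and the CSVs point at
# real clips): pull one batch and print its shapes to check the pipeline end to end.
if __name__ == "__main__":
    sample_batch = next(iter(validation_dataloader))
    if sample_batch is not None:
        waveforms, transcriptions, audio_lens, transcription_lens = sample_batch
        print("waveforms:", waveforms.shape, waveforms.dtype)
        print("transcriptions:", transcriptions.shape, transcriptions.dtype)
        print("audio lengths:", audio_lens[:4].tolist())
        print("transcription lengths:", transcription_lens[:4].tolist())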
|
|