import os
import random

import pandas as pd
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import Resample

# -------------------------
# Tokenizer
# -------------------------
sp = spm.SentencePieceProcessor()
sp.Load("./ressources/tokenizer/128_v7.model")

# -------------------------
# Load CSVs
# -------------------------
train_data = pd.read_csv("./ressources/train.csv", low_memory=False)
validation_data = pd.read_csv("./ressources/dev.csv", low_memory=False)
test_data = pd.read_csv("./ressources/test.csv", low_memory=False)

X_train, y_train = train_data["path"], train_data["sentence"]
X_val, y_val = validation_data["path"], validation_data["sentence"]
X_test, y_test = test_data["path"], test_data["sentence"]

del train_data, validation_data, test_data

audio_location = os.environ.get("AUDIO_LOCATION")

# -------------------------
# Collate Function
# -------------------------
def collate_fn(batch):
    # Drop samples that failed to load; bail out if the whole batch is empty.
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None

    transcriptions, waveforms, audio_lengths = zip(*batch)

    transcriptions = [torch.tensor(t, dtype=torch.long) for t in transcriptions]
    # Waveforms are already tensors; avoid torch.tensor() on a tensor (it copies and warns).
    waveforms = [w.to(torch.float32) for w in waveforms]

    transcription_lengths = torch.tensor(
        [t.size(0) for t in transcriptions], dtype=torch.int32
    )
    audio_lengths = torch.tensor(audio_lengths, dtype=torch.int32)

    # Zero-pad to the longest waveform / transcription in the batch.
    padded_waveforms = pad_sequence(waveforms, batch_first=True, padding_value=0.0)
    padded_transcriptions = pad_sequence(
        transcriptions, batch_first=True, padding_value=0
    )

    return padded_waveforms, padded_transcriptions, audio_lengths, transcription_lengths

# -------------------------
# Dataset
# -------------------------
class AudioDataset(Dataset):
    def __init__(self, X, y, audio_location=audio_location, train=False):
        self.audio_dirs = X.reset_index(drop=True)
        self.transcriptions = y.reset_index(drop=True)
        self.train = train
        self.audio_location = audio_location
        self.target_sr = 16000
        self.resampler = None

    def __len__(self):
        return len(self.transcriptions)

    def __getitem__(self, idx):
        # A row may list several comma-separated clips: pick one at random when
        # training, and always the first one for validation/test so results are
        # reproducible.
        paths = str(self.audio_dirs[idx]).split(",")
        if self.train:
            chosen = random.randint(0, len(paths) - 1)
        else:
            chosen = 0
        audio_path = f"{self.audio_location}/{paths[chosen]}.mp3"

        # ---- Text ----
        transcription = sp.Encode(self.transcriptions[idx], out_type=int)

        # ---- Audio ----
        waveform, sr = torchaudio.load(audio_path)

        # Convert to mono
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to 16 kHz, rebuilding the resampler only when the source rate changes.
        if sr != self.target_sr:
            if self.resampler is None or self.resampler.orig_freq != sr:
                self.resampler = Resample(orig_freq=sr, new_freq=self.target_sr)
            waveform = self.resampler(waveform)

        waveform = waveform.squeeze(0)  # [T]
        return transcription, waveform, waveform.size(0)

# -------------------------
# Datasets
# -------------------------
train_data = AudioDataset(X_train, y_train, train=True)
validation_data = AudioDataset(X_val, y_val)
test_data = AudioDataset(X_test, y_test)

# -------------------------
# DataLoaders
# -------------------------
train_dataloader = DataLoader(
    train_data,
    shuffle=True,
    drop_last=True,
    batch_size=64,
    num_workers=8,
    collate_fn=collate_fn,
    pin_memory=True,
    persistent_workers=True,
)

validation_dataloader = DataLoader(
    validation_data,
    batch_size=64,
    num_workers=4,
    collate_fn=collate_fn,
    persistent_workers=True,
)

test_dataloader = DataLoader(
    test_data,
    batch_size=4,
    num_workers=4,
    collate_fn=collate_fn,
)
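
# -------------------------
# Batch smoke test (illustrative)
# -------------------------
# A minimal sketch of how the loaders above are meant to be consumed, assuming the
# CSVs, the tokenizer model, and the AUDIO_LOCATION directory are in place. The
# __main__ guard and the shape printout below are illustrative additions, not part
# of the original pipeline.
if __name__ == "__main__":
    batch = next(iter(train_dataloader))
    if batch is not None:
        waveforms, transcriptions, audio_lengths, transcription_lengths = batch
        # waveforms:              [batch, max_samples]  zero-padded 16 kHz mono audio
        # transcriptions:         [batch, max_tokens]   zero-padded SentencePiece ids
        # audio_lengths:          [batch]               unpadded sample counts
        # transcription_lengths:  [batch]               unpadded token counts
        print(waveforms.shape, transcriptions.shape)
        print(audio_lengths[:4], transcription_lengths[:4])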