import ast
import os
from typing import Dict, Literal, Tuple

import ffmpeg
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from bytecover.models.data_model import BatchDict
from bytecover.utils import bcolors


class ByteCoverDataset(Dataset):
    """Triplet-sampling audio dataset for cover-song identification.

    Reads split ids, clique labels (``shs100k.csv``) and clique->versions
    lists (``versions.csv``) from ``data_path``, and decodes audio files from
    ``dataset_path`` via the ffmpeg CLI. On the TRAIN split each item carries
    an anchor plus a positive (another version from the same clique) and a
    negative (a track from any other clique); on VAL/TEST the positive and
    negative audio tensors are empty.
    """

    def __init__(
        self,
        data_path: str,
        file_ext: str,
        dataset_path: str,
        data_split: Literal["TRAIN", "VAL", "TEST"],
        debug: bool,
        target_sr: int,
        max_len: int,
    ) -> None:
        """
        Args:
            data_path: Root containing ``splits/`` and ``interim/`` metadata.
            file_ext: Audio file extension (without the dot), e.g. ``"mp3"``.
            dataset_path: Directory holding the audio files named ``<id>.<ext>``.
            data_split: Which id split to load.
            debug: If True, skip decoding and return random audio tensors.
            target_sr: Sample rate the audio is resampled to.
            max_len: Clip length in seconds; ``<= 0`` disables pad/trim.
        """
        super().__init__()
        self.data_path = data_path
        self.file_ext = file_ext
        self.dataset_path = dataset_path
        self.data_split = data_split
        self.debug = debug
        self.target_sr = target_sr
        self.max_len = max_len
        self._load_data()
        # Each track id flows through: decode -> fixed-length pad/trim.
        self.pipeline = transforms.Compose([self._read_audio, self._pad_or_trim_audio])

    def __len__(self) -> int:
        return len(self.track_ids)

    def __getitem__(self, index: int) -> BatchDict:
        """Return the anchor (and, on TRAIN, positive/negative) sample dict."""
        track_id = self.track_ids[index]
        anchor_audio = self.pipeline(track_id)

        clique_id, pos_id, neg_id = self._triplet_sampling(track_id)
        if self.data_split == "TRAIN":
            positive_audio = self.pipeline(pos_id)
            negative_audio = self.pipeline(neg_id)
        else:
            # Ids are still reported for bookkeeping, but no audio is decoded.
            positive_audio = torch.empty(0)
            negative_audio = torch.empty(0)

        return dict(
            anchor_id=track_id,
            anchor=anchor_audio,
            anchor_label=torch.tensor(clique_id, dtype=torch.float),
            positive_id=pos_id,
            positive=positive_audio,
            negative_id=neg_id,
            negative=negative_audio,
        )

    def _triplet_sampling(self, track_id: str) -> Tuple[int, str, str]:
        """Sample a (clique_id, positive_id, negative_id) triplet for ``track_id``.

        Positive: a uniformly random *other* version from the anchor's clique.
        Negative: a uniformly random track from any other clique.

        NOTE(review): assumes every clique has at least two versions —
        ``np.random.choice`` raises on an empty candidate list; confirm
        upstream filtering guarantees this.
        """
        clique_id = self.labels.loc[track_id, "clique"]
        versions = self.versions.loc[clique_id, "versions"]
        # Do NOT shuffle `versions` here: it is the list object stored inside
        # the DataFrame cell, so an in-place shuffle would mutate shared state,
        # and `np.random.choice` below already samples uniformly.
        pos_list = np.setdiff1d(versions, track_id)
        pos_id = np.random.choice(pos_list, 1)[0]
        neg_id = self.labels[~self.labels.index.isin(versions)].sample(1).index[0]
        return (clique_id, pos_id, neg_id)

    def _load_data(self) -> None:
        """Load split ids, per-track clique labels, and clique version lists.

        Clique names are re-indexed to dense integer ids (0..n_cliques-1) so
        they can be used directly as class labels.
        """
        self.track_ids = np.load(
            os.path.join(self.data_path, "splits", f"{self.data_split.lower()}_ids.npy"), allow_pickle=True
        )

        self.labels = pd.read_csv(os.path.join(self.data_path, "interim", "shs100k.csv"), usecols=["clique", "id"])
        self.labels = self.labels[self.labels["id"].isin(self.track_ids)]
        self.labels.dropna(inplace=True)
        self.labels.set_index("id", inplace=True)

        # Map raw clique identifiers to dense integer labels.
        cliques = self.labels["clique"].unique()
        mapping = {clique: k for k, clique in enumerate(cliques)}
        self.labels["clique"] = self.labels["clique"].map(mapping)

        # SECURITY: the "versions" column holds Python-list literals. Use
        # ast.literal_eval (not eval) so a crafted CSV cannot execute code.
        self.versions = pd.read_csv(
            os.path.join(self.data_path, "interim", "versions.csv"), converters={"versions": ast.literal_eval}
        )
        self.versions.dropna(inplace=True)
        self.versions = self.versions[self.versions["clique"].isin(cliques)]
        self.versions["clique"] = self.versions["clique"].map(mapping)
        self.versions.set_index("clique", inplace=True)

    def _read_audio(self, track_id: str) -> torch.Tensor:
        """Decode ``<dataset_path>/<track_id>.<file_ext>`` to a mono float32 waveform.

        In debug mode a random tensor of plausible length is returned instead,
        so the pipeline can run without audio files.

        Raises:
            RuntimeError: if ffmpeg fails to decode the file.
        """
        if self.debug:
            seq_len = np.random.randint(10, 200) if self.max_len <= 0 else self.max_len
            return torch.rand(seq_len * self.target_sr)

        filename = os.path.join(self.dataset_path, f"{track_id}.{self.file_ext}")
        try:
            # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
            # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
            out, _ = (
                ffmpeg.input(filename, threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=self.target_sr)
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            raise RuntimeError(
                f"{bcolors.WARNING}Failed to load audio:{bcolors.FAIL + filename + bcolors.ENDC}\n{e.stderr.decode()}"
            ) from e

        # int16 ranges between -2^15 and +2^15 (±32768). By convention, floating point audio data is
        # normalized to the range of [-1.0, 1.0]
        audio = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
        return torch.from_numpy(audio)

    def _pad_or_trim_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """Force ``audio`` to ``max_len`` seconds: zero-pad (TRAIN) or random-crop.

        NOTE(review): on VAL/TEST, audio shorter than ``max_len`` is returned
        as-is (unpadded); with batch_size > 1 this can break default collation
        — confirm eval batches are length-homogeneous or batch_size == 1.
        """
        if self.max_len <= 0:
            return audio
        target_samples = self.max_len * self.target_sr
        if (self.data_split == "TRAIN") and (audio.shape[-1] <= target_samples):
            return F.pad(audio, (0, target_samples - audio.shape[-1]))
        # Random crop on TRAIN; on VAL/TEST max_offset <= 0 yields offset 0,
        # i.e. a deterministic leading crop (or the whole clip if shorter).
        max_offset = audio.shape[-1] - target_samples
        offset = np.random.randint(max_offset) if max_offset > 0 else 0
        return audio[offset : (offset + target_samples)]


def bytecover_dataloader(
    data_path: str,
    file_ext: str,
    dataset_path: str,
    data_split: Literal["TRAIN", "VAL", "TEST"],
    debug: bool,
    max_len: int,
    batch_size: int,
    target_sr: int,
    **config: Dict,
) -> DataLoader:
    """Build a DataLoader over a :class:`ByteCoverDataset`.

    ``config`` must supply ``num_workers``, ``shuffle`` and ``drop_last``.
    When ``max_len <= 0`` clips keep their native lengths, so batch_size is
    forced to 1 to avoid collating variable-length tensors.
    """
    return DataLoader(
        ByteCoverDataset(data_path, file_ext, dataset_path, data_split, debug, target_sr=target_sr, max_len=max_len),
        batch_size=batch_size if max_len > 0 else 1,
        num_workers=config["num_workers"],
        shuffle=config["shuffle"],
        drop_last=config["drop_last"],
    )