Spaces:
Paused
Paused
| from torch.utils.data import Dataset | |
| from torchvision.datasets.utils import download_url | |
| from tqdm import tqdm | |
| import pandas as pd | |
| import os | |
| import torch.nn as nn | |
| import torch | |
class AudioDataset(Dataset):
    """Common base for audio datasets stored under a root directory.

    The constructor records the user-expanded root path and, when asked to,
    invokes the subclass-provided ``download`` hook. The indexing, length and
    download hooks are placeholders that concrete datasets must override.
    """

    def __init__(self, root: str, download: bool = True):
        """Remember where the dataset lives and optionally fetch it.

        Args:
            root: Directory that holds (or will hold) the dataset; a leading
                ``~`` is expanded via ``os.path.expanduser``.
            download: When true, ``self.download()`` is called right away.
        """
        self.root = os.path.expanduser(root)
        if download:
            self.download()

    def __len__(self):
        """Return the number of items; must be provided by subclasses."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Return the item at ``index``; must be provided by subclasses."""
        raise NotImplementedError

    def download(self):
        """Fetch the dataset into ``self.root``; must be provided by subclasses."""
        raise NotImplementedError
class ESC50(AudioDataset):
    """ESC-50 dataset wrapper yielding ``(file_path, label, one_hot)`` items.

    The archive at ``url`` is downloaded into ``root`` and extracted to
    ``root/base_folder``. Labels come from the ``label_col`` column of the
    metadata CSV, with underscores replaced by spaces. Audio is not decoded
    here: items carry the path to the audio file.
    """

    base_folder = 'ESC-50-master'
    url = "https://github.com/karoldvl/ESC-50/archive/master.zip"
    filename = "ESC-50-master.zip"
    # Number of entries expected in the audio directory once the archive is
    # fully extracted; used to decide whether a download can be skipped.
    num_files_in_dir = 2000
    audio_dir = 'audio'
    label_col = 'category'
    file_col = 'filename'
    meta = {
        'filename': os.path.join('meta', 'esc50.csv'),
    }

    def __init__(self, root, reading_transformations: nn.Module = None, download: bool = True):
        """Load the metadata and build the path/label lists.

        Args:
            root: Directory that contains (or will contain) ``base_folder``.
            reading_transformations: Optional module stored on the instance
                as ``pre_transformations``; this class only stores it.
            download: When true, fetch and extract the archive if it is not
                already present under ``root``.

        Raises:
            FileNotFoundError: If the metadata CSV is missing after the
                optional download step.
        """
        # Forward the flag: previously it was dropped, so the base-class
        # default (True) forced a download even with ``download=False``.
        super().__init__(root, download)
        self._load_meta()
        self.targets, self.audio_paths = [], []
        self.pre_transformations = reading_transformations
        print("Loading audio files")
        # Normalise labels so they match the keys of ``class_to_idx``.
        self.df[self.label_col] = self.df[self.label_col].str.replace('_', ' ', regex=False)
        audio_root = os.path.join(self.root, self.base_folder, self.audio_dir)
        rows = zip(self.df[self.file_col], self.df[self.label_col])
        for file_name, label in tqdm(rows, total=len(self.df)):
            self.targets.append(label)
            self.audio_paths.append(os.path.join(audio_root, file_name))

    def _load_meta(self):
        """Read the metadata CSV and derive ``classes`` / ``class_to_idx``.

        Class names are sorted on their raw form and then have underscores
        replaced by spaces; the index of a class is its position in that list.

        Raises:
            FileNotFoundError: If the metadata CSV does not exist.
        """
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not os.path.isfile(path):
            raise FileNotFoundError(
                f"ESC-50 metadata not found at {path!r}; "
                "pass download=True or check the root directory."
            )
        self.df = pd.read_csv(path)
        self.classes = [x.replace('_', ' ') for x in sorted(self.df[self.label_col].unique())]
        self.class_to_idx = {category: i for i, category in enumerate(self.classes)}

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Args:
            index (int): Position of the sample.

        Returns:
            tuple: ``(file_path, target, one_hot_target)`` where ``file_path``
            is the path to the audio file, ``target`` is the class name and
            ``one_hot_target`` is a float tensor of shape
            ``(1, len(self.classes))`` with a 1 at the class index.
        """
        file_path, target = self.audio_paths[index], self.targets[index]
        one_hot_target = torch.zeros(1, len(self.classes))
        one_hot_target[0, self.class_to_idx[target]] = 1.0
        return file_path, target, one_hot_target

    def __len__(self):
        """Return the number of audio files listed in the metadata."""
        return len(self.audio_paths)

    def _is_extracted(self) -> bool:
        """Tell whether the metadata CSV and all audio entries are on disk."""
        dataset_dir = os.path.join(self.root, self.base_folder)
        meta_path = os.path.join(dataset_dir, self.meta['filename'])
        audio_path = os.path.join(dataset_dir, self.audio_dir)
        if not (os.path.isfile(meta_path) and os.path.isdir(audio_path)):
            return False
        return len(os.listdir(audio_path)) >= self.num_files_in_dir

    def download(self):
        """Download and extract the archive unless it is already extracted."""
        from zipfile import ZipFile

        if self._is_extracted():
            # Avoid re-extracting the whole archive on every instantiation.
            return
        download_url(self.url, self.root, self.filename)
        # Extract next to the archive; the zip contains ``base_folder``.
        with ZipFile(os.path.join(self.root, self.filename), 'r') as archive:
            archive.extractall(path=self.root)