Spaces:
Sleeping
Sleeping
File size: 2,641 Bytes
a3ea780 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Sequence
from src.config.config import DatasetConfig
# Module-level DatasetConfig instance. It is used as the default value of the
# `config` parameter of the dataset constructors in this module, so every
# dataset created without an explicit config shares this single object.
config = DatasetConfig()
class FullTFPatchesDataset(Dataset):
    """Dataset exposing every fixed-length window of every spectrogram.

    Each spectrogram is indexed as a 2-D array whose axis 0 is treated as the
    frame axis. For a spectrogram with at least ``config.cnn_input_length``
    frames, one item is produced per possible start frame (stride 1). A
    shorter spectrogram contributes exactly one item, which is padded at the
    end of axis 0 up to ``config.cnn_input_length`` when it is fetched.

    Attributes:
        config: Configuration object; only ``cnn_input_length`` is read here.
        spectrograms: The sequence of spectrogram arrays passed in (stored by
            reference, not copied).
        patch_indices: List of ``(spec_idx, start_frame, label)`` tuples, one
            per dataset item.
    """

    def __init__(self, spectrograms: Sequence[np.ndarray], labels: Sequence[int], config: DatasetConfig = config) -> None:
        """Build the flat index of all patches.

        Args:
            spectrograms: Arrays whose first axis is the frame axis.
            labels: ``labels[i]`` is the label of ``spectrograms[i]``.
            config: Supplies ``cnn_input_length``, the patch length in frames.
        """
        self.config = config
        self.spectrograms = spectrograms

        window = self.config.cnn_input_length
        index: list[tuple[int, int, int]] = []
        for spec_idx, spec in enumerate(spectrograms):
            label = labels[spec_idx]
            # A spectrogram shorter than the window yields a non-positive
            # count; clamp to 1 so it still gets a single entry at frame 0.
            n_windows = max(spec.shape[0] - window + 1, 1)
            index.extend((spec_idx, first, label) for first in range(n_windows))
        self.patch_indices = index

    def __len__(self) -> int:
        """Return the total number of patches across all spectrograms."""
        return len(self.patch_indices)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``idx``-th patch and its label.

        Returns:
            A ``float32`` tensor holding the patch with a new leading axis of
            size 1, and a ``long`` tensor built from the label.
        """
        spec_idx, start_frame, label = self.patch_indices[idx]
        spec = self.spectrograms[spec_idx]
        window = self.config.cnn_input_length

        missing = window - spec.shape[0]
        if missing > 0:
            # Too short: pad with constant values (NumPy default 0) after the
            # last frame; the second axis is left untouched.
            patch = np.pad(spec, ((0, missing), (0, 0)), mode='constant')
        else:
            patch = spec[start_frame:start_frame + window]

        features = torch.tensor(patch[np.newaxis, :, :], dtype=torch.float32)
        target = torch.tensor(label, dtype=torch.long)
        return features, target
class RandomPatchDataset(Dataset):
    """Dataset returning one randomly positioned patch per spectrogram.

    There is one item per label. Each access to an item draws a fresh start
    frame, so repeated reads of the same index generally return different
    windows. Spectrograms with fewer than ``config.cnn_input_length`` frames
    are padded at the end of axis 0 instead of being cropped.

    The start frame is drawn from PyTorch's RNG rather than NumPy's global
    RNG. PyTorch's data-loading documentation warns that seeds of other
    libraries may be duplicated across ``DataLoader`` workers, which would
    make workers return identical "random" crops; the torch generator is
    seeded per worker. Reproducibility is therefore controlled through
    ``torch.manual_seed`` (or the ``DataLoader`` generator), not
    ``np.random.seed``.

    Attributes:
        config: Configuration object; only ``cnn_input_length`` is read here.
        spectrograms: The sequence of spectrogram arrays passed in (stored by
            reference, not copied).
        labels: The sequence of labels passed in; its length defines the
            dataset length.
    """

    def __init__(self, spectrograms: Sequence[np.ndarray], labels: Sequence[int], config: DatasetConfig = config) -> None:
        """Store the inputs; no index is precomputed.

        Args:
            spectrograms: Arrays whose first axis is the frame axis.
            labels: ``labels[i]`` is the label of ``spectrograms[i]``.
            config: Supplies ``cnn_input_length``, the patch length in frames.
        """
        self.config = config
        self.spectrograms = spectrograms
        self.labels = labels

    def __len__(self) -> int:
        """Return the number of labels (one item per label)."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return a random patch of spectrogram ``idx`` and its label.

        Returns:
            A ``float32`` tensor holding the patch with a new leading axis of
            size 1, and a ``long`` tensor built from the label.
        """
        spec = self.spectrograms[idx]
        label = self.labels[idx]
        window = self.config.cnn_input_length
        n_frames = spec.shape[0]

        if n_frames >= window:
            # Upper bound is exclusive, so the last valid start frame is
            # n_frames - window (same range the NumPy call covered).
            start = int(torch.randint(0, n_frames - window + 1, (1,)).item())
            patch = spec[start:start + window]
        else:
            # Too short: pad with constant values (NumPy default 0) after the
            # last frame; the second axis is left untouched.
            patch = np.pad(spec, ((0, window - n_frames), (0, 0)), mode='constant')

        patch = patch[np.newaxis, :, :]
        return torch.tensor(patch, dtype=torch.float32), torch.tensor(label, dtype=torch.long)