import torch
import numpy as np
import pandas as pd
import librosa
import torchaudio as audio

class SEP28KDataset(torch.utils.data.Dataset):
    """SEP-28k Dataset."""

    def __init__(self, x, y, unsqueeze=False, transform=None):
        """
        Args:
            x (hdf5): HDF5 dataset of audio clips, one of 'Xtrain', 'Xtest', or 'Xvalid'.
            y (hdf5): HDF5 dataset of labels, one of 'Ytrain', 'Ytest', or 'Yvalid'.
            unsqueeze (bool, Optional): Whether to add a leading dimension to the
              feature. May be needed for models that expect image-like inputs.
            transform (callable, Optional): Optional transform to be applied
                on a sample.
        """
        self.data = x
        self.labels = y
        # self.spec = audio.transforms.MelSpectrogram(n_mels=80, sample_rate=16000,
        #                                       n_fft=512, f_max=8000, f_min=0,
        #                                       power=0.5, hop_length=152, win_length=480)
        # self.db = audio.transforms.AmplitudeToDB()

        # self.freq_mask = audio.transforms.FrequencyMasking(freq_mask_param=1)
        # self.time_mask = audio.transforms.TimeMasking(time_mask_param=20)

        # self.rng = np.random.default_rng(42)
        # self.rng_2 = np.random.default_rng(68)
        self.unsqueeze = unsqueeze

        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # load sliced clip
        # _, wav = wavfile.read(clip_path)
        wav = self.data[idx]
        # pad or truncate every clip to 3000 ms at 16 kHz (48000 samples)
        wav = self.pad_trunc(wav, 3000, 16000).astype('float32')

        wav = torch.tensor(wav)
        #wav = self.spec(wav)
        #wav = self.db(wav)

        #if (self.rng.choice(2,p=[0.2,0.8])):
        #  wav = self.freq_mask(wav)

        # if (self.rng_2.choice(2,p=[0.2,0.8])):
        #   wav = self.time_mask(wav)

        # get labels
        labels = self.labels[idx].astype('float32')

        if self.transform is not None:
            wav = self.transform(wav)
        
        if self.unsqueeze:
            wav = torch.unsqueeze(wav, 0)

        # wav is already a tensor here; as_tensor avoids the copy-construct
        # warning that torch.tensor(tensor) raises
        return torch.as_tensor(wav).clone().detach(), torch.tensor(labels)

    @staticmethod
    def pad_trunc(sig, max_ms, sr):
        """Pad or truncate a 1-D signal to max_ms milliseconds at sample rate sr."""
        sig_len = sig.shape[0]
        max_len = sr // 1000 * max_ms

        if sig_len > max_len:
            # Truncate the signal to the given length
            # (sig is 1-D, so slice along the only axis)
            sig = sig[:max_len]

        elif sig_len < max_len:
            # Length of padding to add at the beginning and end of the signal
            pad_begin_len = np.random.randint(0, max_len - sig_len)
            pad_end_len = max_len - sig_len - pad_begin_len

            # Pad with 0s
            pad_begin = np.zeros(pad_begin_len)
            pad_end = np.zeros(pad_end_len)

            sig = np.concatenate((pad_begin, sig, pad_end), 0)

        return sig
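
# A minimal sketch of the pad-or-truncate step used by pad_trunc, written as a
# standalone function so it can be tried without HDF5 data. The name
# pad_trunc_1d and the optional rng argument are assumptions for illustration
# and are not part of the original class; the seeded generator only makes the
# random padding offset reproducible. Unlike pad_trunc, this variant lets the
# offset reach the full padding length and pads with np.pad.

```python
import numpy as np


def pad_trunc_1d(sig, max_ms, sr, rng=None):
    """Pad or truncate a 1-D signal to exactly max_ms milliseconds.

    Hypothetical standalone variant of SEP28KDataset.pad_trunc; `rng` is an
    illustrative addition that makes the random offset reproducible.
    """
    rng = np.random.default_rng() if rng is None else rng
    max_len = sr // 1000 * max_ms

    if sig.shape[0] >= max_len:
        # Keep only the first max_len samples.
        return sig[:max_len]

    # Split the missing samples randomly between the start and the end.
    missing = max_len - sig.shape[0]
    pad_begin_len = int(rng.integers(0, missing + 1))
    pad_end_len = missing - pad_begin_len
    return np.pad(sig, (pad_begin_len, pad_end_len), mode="constant")


short_clip = np.ones(16000, dtype="float32")  # 1 s at 16 kHz
long_clip = np.ones(64000, dtype="float32")   # 4 s at 16 kHz

padded = pad_trunc_1d(short_clip, 3000, 16000, rng=np.random.default_rng(42))
truncated = pad_trunc_1d(long_clip, 3000, 16000)
```

# Both results have 48000 samples (3000 ms at 16 kHz); the padded clip keeps
# its 16000 original samples and fills the rest with zeros.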