File size: 8,246 Bytes
f2688f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import torchaudio
import torch
import numpy as np
import soundfile
class AudioLoader:
    """Load audio files as mono waveforms at a common sample rate."""

    def __init__(self, sample_rate=16000):
        """Remember the sample rate that loaded audio is converted to.

        Args:
            sample_rate: Target sample rate in Hz.
        """
        self.sample_rate = sample_rate

    def load_audio(self, file_path):
        """Load ``file_path`` and return a 1-D waveform at ``self.sample_rate``.

        Args:
            file_path: Path of the audio file, read through torchaudio's
                ``soundfile`` backend.

        Returns:
            A 1-D tensor of samples. Multi-channel files are averaged down
            to a single channel.
        """
        audio, sample_rate = torchaudio.load(file_path, backend='soundfile')
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            audio = resampler(audio)

        # torchaudio.load yields (channels, time) by default. squeeze(0) only
        # drops the channel axis for mono files, so a stereo file used to come
        # back 2-D and broke callers that expect a 1-D signal. Downmix instead.
        if audio.dim() > 1 and audio.shape[0] > 1:
            return audio.mean(dim=0)
        return audio.squeeze(0)

class STFT:
    """Short-time Fourier transform using a Hamming analysis window."""

    def __init__(self, n_fft=1024, hop_length=512, win_length=1024):
        """Store the framing parameters forwarded to ``torch.stft``.

        Args:
            n_fft: FFT size.
            hop_length: Hop length between frames.
            win_length: Length of the Hamming window.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def compute_stft(self, signal):
        """Return the complex-valued STFT of ``signal``.

        Args:
            signal: Input tensor passed to ``torch.stft``.

        Returns:
            The complex tensor produced by ``torch.stft`` with
            ``return_complex=True``.
        """
        # Create the window on the signal's device: torch.stft rejects a CPU
        # window when the signal lives on another device (e.g. CUDA).
        window = torch.hamming_window(self.win_length, device=signal.device)
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )

class SpectrogramSaver:
    """Namespace for writing spectrogram tensors to storage."""

    @staticmethod
    def save_spectrogram(spectrogram, save_path):
        """Serialize ``spectrogram`` to ``save_path`` with ``torch.save``.

        Args:
            spectrogram: Object to serialize.
            save_path: Destination handed to ``torch.save`` unchanged.
        """
        destination = save_path
        torch.save(obj=spectrogram, f=destination)

# class Preprocessing:
#     def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024):
#         self.loader = AudioLoader(sample_rate)
#         self.stft = STFT(n_fft, hop_length, win_length)
#         self.saver = SpectrogramSaver()
#         self.fixed_length = None

#     def preprocess(self, signal):
#         spectrogram = self.stft.compute_stft(signal)
#         real = spectrogram.real
#         imag = spectrogram.imag
#         combined = torch.stack((real, imag), dim=-1)  # Shape: (num_frequency_bins, num_frames, 2)
#         return combined

#     def determine_fixed_length(self, noisy_dir, clean_dir):
#         lengths = []
#         noisy_files = [os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.wav')]
#         clean_files = [os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.wav')]

#         for noisy_file, clean_file in zip(noisy_files, clean_files):
#             noisy_audio = self.loader.load_audio(noisy_file)
#             clean_audio = self.loader.load_audio(clean_file)

#             noisy_spectrogram = self.preprocess(noisy_audio)
#             clean_spectrogram = self.preprocess(clean_audio)

#             lengths.append(noisy_spectrogram.shape[1])
#             lengths.append(clean_spectrogram.shape[1])

#         self.fixed_length = int(np.median(lengths))
#         print(f"Determined fixed length: {self.fixed_length}")

#     def create_dataset(self, noisy_dir, clean_dir, save_dir):
#         if self.fixed_length is None:
#             self.determine_fixed_length(noisy_dir, clean_dir)

#         noisy_save_dir = os.path.join(save_dir, 'noisy')
#         clean_save_dir = os.path.join(save_dir, 'clean')
        
#         if not os.path.exists(noisy_save_dir):
#             os.makedirs(noisy_save_dir)
#         if not os.path.exists(clean_save_dir):
#             os.makedirs(clean_save_dir)

#         noisy_files = [os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.wav')]
#         clean_files = [os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.wav')]

#         for noisy_file, clean_file in zip(noisy_files, clean_files):
#             noisy_audio = self.loader.load_audio(noisy_file)
#             clean_audio = self.loader.load_audio(clean_file)

#             noisy_spectrogram = self.preprocess(noisy_audio)
#             clean_spectrogram = self.preprocess(clean_audio)

#             noisy_spectrogram = self.pad_spectrogram(noisy_spectrogram)
#             clean_spectrogram = self.pad_spectrogram(clean_spectrogram)

#             noisy_save_path = os.path.join(noisy_save_dir, f"noisy_{os.path.basename(noisy_file).split('.')[0]}.pt")
#             clean_save_path = os.path.join(clean_save_dir, f"clean_{os.path.basename(clean_file).split('.')[0]}.pt")

#             self.saver.save_spectrogram(noisy_spectrogram, noisy_save_path)
#             self.saver.save_spectrogram(clean_spectrogram, clean_save_path)

#     def pad_spectrogram(self, spectrogram):
#         pad_length = self.fixed_length - spectrogram.shape[1]
#         if pad_length > 0:
#             pad = torch.zeros((spectrogram.shape[0], pad_length, spectrogram.shape[2]))
#             spectrogram = torch.cat((spectrogram, pad), dim=1)
#         elif pad_length < 0:
#             spectrogram = spectrogram[:, :self.fixed_length, :]
#         return spectrogram
class Preprocessing:
    """Convert paired noisy/clean audio directories into saved spectrograms.

    Each waveform is transformed with an STFT, split into real and imaginary
    parts stacked on a trailing axis, fitted to ``fixed_length`` entries along
    axis 1, and written to disk with ``torch.save``.
    """

    def __init__(self, sample_rate, n_fft, hop_length, win_length, fixed_length=86):
        """Set up the loader and STFT used for every file.

        Args:
            sample_rate: Target sample rate given to ``AudioLoader``.
            n_fft: FFT size given to ``STFT``.
            hop_length: Hop length given to ``STFT``.
            win_length: Window length given to ``STFT``.
            fixed_length: Size of axis 1 of every saved spectrogram. Defaults
                to 86, the value that used to be hard-coded. Pass ``None`` to
                let ``create_dataset`` derive it from the data through
                ``determine_fixed_length``.
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.fixed_length = fixed_length
        self.stft = STFT(n_fft, hop_length, win_length)
        self.loader = AudioLoader(sample_rate)

    def preprocess(self, signal):
        """Return the STFT of ``signal`` as stacked real/imaginary parts.

        Args:
            signal: Waveform tensor.

        Returns:
            A real tensor whose last axis has size 2 (real, imaginary), or
            ``None`` when the signal has no samples.
        """
        if signal.shape[-1] == 0:
            print("Encountered zero-length signal, skipping...")
            return None  # Callers skip pairs that contain an empty signal.
        spectrogram = self.stft.compute_stft(signal)
        return torch.stack((spectrogram.real, spectrogram.imag), dim=-1)

    @staticmethod
    def _list_wav_files(directory):
        """Return the sorted ``.wav`` file names found in ``directory``."""
        return sorted(
            name for name in os.listdir(directory) if name.lower().endswith('.wav')
        )

    def _paired_wav_files(self, noisy_dir, clean_dir):
        """Pair noisy and clean ``.wav`` names by sorted position."""
        noisy_files = self._list_wav_files(noisy_dir)
        clean_files = self._list_wav_files(clean_dir)
        if len(noisy_files) != len(clean_files):
            # zip() silently drops the surplus, so make the mismatch visible.
            print(
                f"Warning: found {len(noisy_files)} noisy and {len(clean_files)} "
                "clean .wav files; unmatched trailing files are ignored."
            )
        return list(zip(noisy_files, clean_files))

    def _load_spectrogram_pair(self, noisy_path, clean_path):
        """Load both files and return their spectrograms (either may be None)."""
        noisy_audio = self.loader.load_audio(noisy_path)
        clean_audio = self.loader.load_audio(clean_path)
        return self.preprocess(noisy_audio), self.preprocess(clean_audio)

    def _fit_to_fixed_length(self, spectrogram):
        """Zero-pad or truncate axis 1 of ``spectrogram`` to ``fixed_length``."""
        if self.fixed_length is None:
            return spectrogram
        missing = self.fixed_length - spectrogram.shape[1]
        if missing > 0:
            padding = spectrogram.new_zeros(
                (spectrogram.shape[0], missing, spectrogram.shape[2])
            )
            return torch.cat((spectrogram, padding), dim=1)
        # clone() detaches the slice from the full-size storage; torch.save
        # would otherwise serialize the whole untruncated buffer.
        return spectrogram[:, :self.fixed_length, :].clone()

    def determine_fixed_length(self, noisy_dir, clean_dir):
        """Set ``fixed_length`` to the shortest axis-1 size across all pairs.

        Args:
            noisy_dir: Directory containing the noisy ``.wav`` files.
            clean_dir: Directory containing the clean ``.wav`` files.

        ``fixed_length`` is left unchanged when no valid pair is found.
        """
        lengths = []
        for noisy_file, clean_file in self._paired_wav_files(noisy_dir, clean_dir):
            noisy_spectrogram, clean_spectrogram = self._load_spectrogram_pair(
                os.path.join(noisy_dir, noisy_file),
                os.path.join(clean_dir, clean_file),
            )
            if noisy_spectrogram is None or clean_spectrogram is None:
                continue  # Skip any zero-length signals
            lengths.append(noisy_spectrogram.shape[1])
            lengths.append(clean_spectrogram.shape[1])

        if lengths:
            self.fixed_length = min(lengths)
            print(f"Determined fixed length: {self.fixed_length}")
        else:
            print("No valid spectrograms found.")

    def create_dataset(self, noisy_dir, clean_dir, save_dir):
        """Preprocess every noisy/clean pair and save the results as ``.pt``.

        Args:
            noisy_dir: Directory containing the noisy ``.wav`` files.
            clean_dir: Directory containing the clean ``.wav`` files.
            save_dir: Output root; tensors are written to its ``noisy`` and
                ``clean`` subdirectories, which are created if missing.
        """
        if self.fixed_length is None:
            self.determine_fixed_length(noisy_dir, clean_dir)

        noisy_save_dir = os.path.join(save_dir, 'noisy')
        clean_save_dir = os.path.join(save_dir, 'clean')
        os.makedirs(noisy_save_dir, exist_ok=True)
        os.makedirs(clean_save_dir, exist_ok=True)

        for noisy_file, clean_file in self._paired_wav_files(noisy_dir, clean_dir):
            noisy_spectrogram, clean_spectrogram = self._load_spectrogram_pair(
                os.path.join(noisy_dir, noisy_file),
                os.path.join(clean_dir, clean_file),
            )
            if noisy_spectrogram is None or clean_spectrogram is None:
                continue  # Skip any zero-length signals

            noisy_spectrogram = self._fit_to_fixed_length(noisy_spectrogram)
            clean_spectrogram = self._fit_to_fixed_length(clean_spectrogram)

            # splitext swaps only the real extension, whatever its letter case.
            noisy_name = os.path.splitext(noisy_file)[0] + '.pt'
            clean_name = os.path.splitext(clean_file)[0] + '.pt'
            torch.save(noisy_spectrogram, os.path.join(noisy_save_dir, noisy_name))
            torch.save(clean_spectrogram, os.path.join(clean_save_dir, clean_name))


# # Example usage
# if __name__ == "__main__":
#     noisy_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/Babble_noise_speech_train"
#     clean_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/clean_train"
#     save_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/preprocessed_data"

#     preprocessor = Preprocessing(sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024)
#     preprocessor.create_dataset(noisy_dir, clean_dir, save_dir)