# ASR / Pre_processing_test.py
# SIDD2201's picture
# Upload 363 files
# f2688f7 verified
import os
import torchaudio
import torch
import numpy as np
import soundfile
class AudioLoader:
    """Load audio files as 1-D waveforms at a fixed target sample rate."""

    def __init__(self, sample_rate=16000):
        """Remember the sample rate (Hz) that every loaded file is converted to."""
        self.sample_rate = sample_rate

    @staticmethod
    def _to_mono(audio):
        """Reduce a channels-first waveform to a single 1-D channel.

        A tensor whose leading (channel) axis has more than one entry is
        averaged over that axis. Anything else goes through ``squeeze(0)``,
        which drops a size-1 leading axis and leaves other tensors as-is.
        """
        if audio.dim() == 2 and audio.shape[0] > 1:
            # squeeze(0) would be a no-op here and leak a 2-D tensor to
            # callers that expect a 1-D signal, so downmix explicitly.
            return audio.mean(dim=0)
        return audio.squeeze(0)

    def load_audio(self, file_path):
        """Read ``file_path`` and return it as a mono waveform.

        The file is decoded with torchaudio's ``soundfile`` backend, resampled
        to ``self.sample_rate`` when its native rate differs, and finally
        downmixed to one channel.
        """
        audio, sample_rate = torchaudio.load(file_path, backend='soundfile')
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            audio = resampler(audio)
        return self._to_mono(audio)
class STFT:
    """Short-time Fourier transform with a Hamming analysis window."""

    def __init__(self, n_fft=1024, hop_length=512, win_length=1024):
        """Store the FFT size, hop size and window length (all in samples)."""
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def compute_stft(self, signal):
        """Return the complex STFT of ``signal`` as computed by ``torch.stft``.

        The Hamming window is allocated on the same device as ``signal`` so
        that the transform also works for tensors that do not live on the
        default device.
        """
        window = torch.hamming_window(self.win_length, device=signal.device)
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )
class SpectrogramSaver:
    """Persist spectrogram tensors to disk."""

    @staticmethod
    def save_spectrogram(spectrogram, save_path):
        """Serialize ``spectrogram`` to ``save_path`` using ``torch.save``."""
        torch.save(obj=spectrogram, f=save_path)
class Preprocessing:
    """Turn a directory of noisy ``.wav`` files into fixed-length spectrograms.

    Each file is loaded, transformed with an STFT, split into real and
    imaginary parts, padded or truncated along axis 1 to ``fixed_length``
    and saved as a ``.pt`` file.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024):
        """Build the loader, STFT and saver; ``fixed_length`` starts unset."""
        self.loader = AudioLoader(sample_rate)
        self.stft = STFT(n_fft, hop_length, win_length)
        self.saver = SpectrogramSaver()
        # Set by determine_fixed_length() (or by the caller) before padding.
        self.fixed_length = None

    @staticmethod
    def _list_wav_files(directory):
        """Return the paths of the ``.wav`` files in ``directory``, sorted by name."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith('.wav')
        ]

    @staticmethod
    def _output_name(noisy_file):
        """Return the ``noisy_<stem>.pt`` file name for an input audio path."""
        # splitext removes only the final extension, so names that contain
        # extra dots ("a.01.wav", "a.02.wav") no longer map to the same output.
        stem = os.path.splitext(os.path.basename(noisy_file))[0]
        return f"noisy_{stem}.pt"

    def preprocess(self, signal):
        """Return the STFT of ``signal`` with real/imag parts on a new last axis.

        For a 1-D signal ``torch.stft`` yields (frequency_bins, frames), so
        the result has shape (frequency_bins, frames, 2).
        """
        spectrogram = self.stft.compute_stft(signal)
        return torch.stack((spectrogram.real, spectrogram.imag), dim=-1)

    def determine_fixed_length(self, noisy_dir):
        """Set ``fixed_length`` to the median axis-1 size over all ``.wav`` files.

        Raises:
            ValueError: if ``noisy_dir`` contains no ``.wav`` files.
        """
        noisy_files = self._list_wav_files(noisy_dir)
        if not noisy_files:
            # np.median([]) is NaN and int(NaN) fails with an opaque message.
            raise ValueError(
                f"no .wav files found in {noisy_dir!r}; cannot determine fixed length"
            )
        lengths = [
            self.preprocess(self.loader.load_audio(path)).shape[1]
            for path in noisy_files
        ]
        self.fixed_length = int(np.median(lengths))
        print(f"Determined fixed length: {self.fixed_length}")

    def create_dataset(self, noisy_dir, save_dir):
        """Preprocess every ``.wav`` in ``noisy_dir`` into ``save_dir/noisy``.

        ``fixed_length`` is derived from ``noisy_dir`` first when it has not
        been set yet.
        """
        if self.fixed_length is None:
            self.determine_fixed_length(noisy_dir)
        noisy_save_dir = os.path.join(save_dir, 'noisy')
        os.makedirs(noisy_save_dir, exist_ok=True)
        for noisy_file in self._list_wav_files(noisy_dir):
            noisy_audio = self.loader.load_audio(noisy_file)
            noisy_spectrogram = self.pad_spectrogram(self.preprocess(noisy_audio))
            noisy_save_path = os.path.join(noisy_save_dir, self._output_name(noisy_file))
            self.saver.save_spectrogram(noisy_spectrogram, noisy_save_path)

    def pad_spectrogram(self, spectrogram):
        """Zero-pad or truncate ``spectrogram`` along axis 1 to ``fixed_length``.

        Expects a 3-D tensor; a tensor that already has the target size is
        returned unchanged.
        """
        pad_length = self.fixed_length - spectrogram.shape[1]
        if pad_length > 0:
            # Allocate the padding on the input's device so torch.cat does not
            # fail for tensors that are not on the default device.
            pad = torch.zeros(
                (spectrogram.shape[0], pad_length, spectrogram.shape[2]),
                device=spectrogram.device,
            )
            return torch.cat((spectrogram, pad), dim=1)
        if pad_length < 0:
            return spectrogram[:, :self.fixed_length, :]
        return spectrogram
# Example usage for training:
# if __name__ == "__main__":
#     noisy_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/Babble_noise_speech_train"
#     save_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/preprocessed_data"
#     preprocessor = Preprocessing(sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024)
#     preprocessor.create_dataset(noisy_dir, save_dir)