ASR / Pre_processing.py
SIDD2201's picture
Upload 363 files
f2688f7 verified
import os
import torchaudio
import torch
import numpy as np
import soundfile
class AudioLoader:
    """Load audio files as 1-D mono waveforms at a fixed sample rate."""

    def __init__(self, sample_rate=16000):
        """Store the target sample rate.

        Args:
            sample_rate: Rate (in Hz) that every loaded file is resampled to.
        """
        self.sample_rate = sample_rate

    def load_audio(self, file_path):
        """Read ``file_path`` and return its waveform at ``self.sample_rate``.

        The file is decoded with torchaudio's ``soundfile`` backend and is
        resampled only when its native rate differs from the target rate.

        Args:
            file_path: Path of the audio file to read.

        Returns:
            A 1-D tensor of samples. Multi-channel recordings are averaged
            down to a single channel.
        """
        audio, sample_rate = torchaudio.load(file_path, backend='soundfile')
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            audio = resampler(audio)
        return self._to_mono(audio)

    @staticmethod
    def _to_mono(audio):
        """Collapse a (channels, time) waveform into a 1-D (time,) tensor."""
        # squeeze(0) alone is a no-op when there is more than one channel,
        # which would hand a 2-D tensor to code that expects a 1-D signal.
        if audio.dim() > 1 and audio.shape[0] > 1:
            return audio.mean(dim=0)
        return audio.squeeze(0)
class STFT:
    """Short-time Fourier transform using a Hamming analysis window."""

    def __init__(self, n_fft=1024, hop_length=512, win_length=1024):
        """Store the STFT geometry.

        Args:
            n_fft: FFT size passed to ``torch.stft``.
            hop_length: Step between successive frames, in samples.
            win_length: Length of the Hamming window, in samples.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def compute_stft(self, signal):
        """Return the complex STFT of ``signal``.

        Args:
            signal: Tensor of audio samples accepted by ``torch.stft``.

        Returns:
            The complex-valued tensor produced by ``torch.stft`` with
            ``return_complex=True``.
        """
        # torch.stft requires the window to live on the same device as the
        # input, so build it there instead of always on the CPU.
        window = torch.hamming_window(self.win_length, device=signal.device)
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )
class SpectrogramSaver:
    """Namespace for writing spectrogram tensors to disk."""

    @staticmethod
    def save_spectrogram(spectrogram, save_path):
        """Serialize ``spectrogram`` to ``save_path`` via ``torch.save``.

        Args:
            spectrogram: Object to serialize (a tensor in this pipeline).
            save_path: Destination handed straight to ``torch.save``.
        """
        torch.save(obj=spectrogram, f=save_path)
# class Preprocessing:
# def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024):
# self.loader = AudioLoader(sample_rate)
# self.stft = STFT(n_fft, hop_length, win_length)
# self.saver = SpectrogramSaver()
# self.fixed_length = None
# def preprocess(self, signal):
# spectrogram = self.stft.compute_stft(signal)
# real = spectrogram.real
# imag = spectrogram.imag
# combined = torch.stack((real, imag), dim=-1) # Shape: (num_frequency_bins, num_frames, 2)
# return combined
# def determine_fixed_length(self, noisy_dir, clean_dir):
# lengths = []
# noisy_files = [os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.wav')]
# clean_files = [os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.wav')]
# for noisy_file, clean_file in zip(noisy_files, clean_files):
# noisy_audio = self.loader.load_audio(noisy_file)
# clean_audio = self.loader.load_audio(clean_file)
# noisy_spectrogram = self.preprocess(noisy_audio)
# clean_spectrogram = self.preprocess(clean_audio)
# lengths.append(noisy_spectrogram.shape[1])
# lengths.append(clean_spectrogram.shape[1])
# self.fixed_length = int(np.median(lengths))
# print(f"Determined fixed length: {self.fixed_length}")
# def create_dataset(self, noisy_dir, clean_dir, save_dir):
# if self.fixed_length is None:
# self.determine_fixed_length(noisy_dir, clean_dir)
# noisy_save_dir = os.path.join(save_dir, 'noisy')
# clean_save_dir = os.path.join(save_dir, 'clean')
# if not os.path.exists(noisy_save_dir):
# os.makedirs(noisy_save_dir)
# if not os.path.exists(clean_save_dir):
# os.makedirs(clean_save_dir)
# noisy_files = [os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.wav')]
# clean_files = [os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.wav')]
# for noisy_file, clean_file in zip(noisy_files, clean_files):
# noisy_audio = self.loader.load_audio(noisy_file)
# clean_audio = self.loader.load_audio(clean_file)
# noisy_spectrogram = self.preprocess(noisy_audio)
# clean_spectrogram = self.preprocess(clean_audio)
# noisy_spectrogram = self.pad_spectrogram(noisy_spectrogram)
# clean_spectrogram = self.pad_spectrogram(clean_spectrogram)
# noisy_save_path = os.path.join(noisy_save_dir, f"noisy_{os.path.basename(noisy_file).split('.')[0]}.pt")
# clean_save_path = os.path.join(clean_save_dir, f"clean_{os.path.basename(clean_file).split('.')[0]}.pt")
# self.saver.save_spectrogram(noisy_spectrogram, noisy_save_path)
# self.saver.save_spectrogram(clean_spectrogram, clean_save_path)
# def pad_spectrogram(self, spectrogram):
# pad_length = self.fixed_length - spectrogram.shape[1]
# if pad_length > 0:
# pad = torch.zeros((spectrogram.shape[0], pad_length, spectrogram.shape[2]))
# spectrogram = torch.cat((spectrogram, pad), dim=1)
# elif pad_length < 0:
# spectrogram = spectrogram[:, :self.fixed_length, :]
# return spectrogram
class Preprocessing:
    """Convert paired noisy/clean audio folders into fixed-length STFT tensors.

    Files are paired by position after sorting each directory listing, so the
    two folders are expected to contain matching recordings in the same
    order; pairing stops at the end of the shorter listing.
    """

    def __init__(self, sample_rate, n_fft, hop_length, win_length, fixed_length=86):
        """Configure the loader and the STFT.

        Args:
            sample_rate: Target sample rate handed to ``AudioLoader``.
            n_fft: FFT size handed to ``STFT``.
            hop_length: Hop size handed to ``STFT``.
            win_length: Window length handed to ``STFT``.
            fixed_length: Number of frames (axis 1) every saved spectrogram
                has. Pass ``None`` to let ``create_dataset`` derive it from
                the data through ``determine_fixed_length``.
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.fixed_length = fixed_length
        self.stft = STFT(n_fft, hop_length, win_length)
        self.loader = AudioLoader(sample_rate)

    def preprocess(self, signal):
        """Return the STFT of ``signal`` with real/imag parts on a last axis.

        Args:
            signal: Waveform tensor; its last axis holds the samples.

        Returns:
            ``None`` for a zero-length signal, otherwise a real tensor whose
            trailing axis of size 2 holds (real, imaginary).
        """
        if signal.shape[-1] == 0:
            print("Encountered zero-length signal, skipping...")
            return None  # Caller is expected to skip this signal
        spectrogram = self.stft.compute_stft(signal)
        return torch.stack((spectrogram.real, spectrogram.imag), dim=-1)

    def determine_fixed_length(self, noisy_dir, clean_dir):
        """Set ``self.fixed_length`` to the shortest frame count in the data.

        Leaves ``self.fixed_length`` untouched when no pair could be
        processed.

        Args:
            noisy_dir: Directory containing the noisy recordings.
            clean_dir: Directory containing the clean recordings.
        """
        lengths = []
        for _, _, noisy_spec, clean_spec in self._iter_spectrogram_pairs(noisy_dir, clean_dir):
            lengths.append(noisy_spec.shape[1])
            lengths.append(clean_spec.shape[1])
        if lengths:
            self.fixed_length = min(lengths)
            print(f"Determined fixed length: {self.fixed_length}")
        else:
            print("No valid spectrograms found.")

    def create_dataset(self, noisy_dir, clean_dir, save_dir):
        """Write one ``.pt`` spectrogram per input file under ``save_dir``.

        Outputs go to ``save_dir/noisy`` and ``save_dir/clean`` and keep the
        input file's base name. Pairs containing a zero-length signal are
        skipped.

        Args:
            noisy_dir: Directory containing the noisy recordings.
            clean_dir: Directory containing the clean recordings.
            save_dir: Root directory for the generated tensors.
        """
        if self.fixed_length is None:
            self.determine_fixed_length(noisy_dir, clean_dir)
        noisy_save_dir = os.path.join(save_dir, 'noisy')
        clean_save_dir = os.path.join(save_dir, 'clean')
        os.makedirs(noisy_save_dir, exist_ok=True)
        os.makedirs(clean_save_dir, exist_ok=True)
        pairs = self._iter_spectrogram_pairs(noisy_dir, clean_dir)
        for noisy_file, clean_file, noisy_spec, clean_spec in pairs:
            torch.save(
                self._fit_length(noisy_spec),
                os.path.join(noisy_save_dir, self._output_name(noisy_file)),
            )
            torch.save(
                self._fit_length(clean_spec),
                os.path.join(clean_save_dir, self._output_name(clean_file)),
            )

    def _iter_spectrogram_pairs(self, noisy_dir, clean_dir):
        """Yield (noisy_file, clean_file, noisy_spec, clean_spec) per usable pair."""
        file_pairs = zip(sorted(os.listdir(noisy_dir)), sorted(os.listdir(clean_dir)))
        for noisy_file, clean_file in file_pairs:
            noisy_audio = self.loader.load_audio(os.path.join(noisy_dir, noisy_file))
            clean_audio = self.loader.load_audio(os.path.join(clean_dir, clean_file))
            noisy_spec = self.preprocess(noisy_audio)
            clean_spec = self.preprocess(clean_audio)
            if noisy_spec is None or clean_spec is None:
                continue  # Skip any zero-length signals
            yield noisy_file, clean_file, noisy_spec, clean_spec

    def _fit_length(self, spectrogram):
        """Truncate or zero-pad axis 1 of ``spectrogram`` to ``self.fixed_length``."""
        if self.fixed_length is None:
            # No target length could be determined; keep the tensor as it is.
            return spectrogram
        missing = self.fixed_length - spectrogram.shape[1]
        if missing > 0:
            # Pad spec is (last-axis left, right, axis-1 left, right):
            # append zero frames at the end of axis 1 only.
            return torch.nn.functional.pad(spectrogram, (0, 0, 0, missing))
        # clone() so torch.save serializes only the kept frames rather than
        # the full storage that a sliced view still references.
        return spectrogram[:, :self.fixed_length, :].clone()

    @staticmethod
    def _output_name(file_name):
        """Return the base name of ``file_name`` with its extension replaced by ``.pt``."""
        stem, _ = os.path.splitext(os.path.basename(file_name))
        return stem + '.pt'
# # Example usage
# if __name__ == "__main__":
# noisy_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/Babble_noise_speech_train"
# clean_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/clean_train"
# save_dir = "/home/siddharth/Myprojects/ASR_project/Hybrid_CRN_SFANC-FxNLMS/preprocessed_data"
# preprocessor = Preprocessing(sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024)
# preprocessor.create_dataset(noisy_dir, clean_dir, save_dir)