In [1]:
import os
import torch
import torchaudio
import soundfile as sf
from torch.utils.data import DataLoader, Dataset
from Pre_processing_test import Preprocessing
from Deep_ANC_model_trim import CRN
import numpy as np
from collections import defaultdict
import shutil

In [2]:

MEDIAN_LENGTH = 44000 # Set this to the determined median length from your training data
TEMP_DIR = "temp_padded_files" # Temporary directory to store padded files
FIXED_SEGMENT_LENGTH = 86 # Fixed length for segments

class NoisySpeechTestDataset(Dataset):
 def __init__(self, noisy_dir, preprocessor):
 self.noisy_dir = noisy_dir
 self.preprocessor = preprocessor
 self.noisy_files = [os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.pt')]
 self.noisy_files.sort() # Ensure the files are sorted
 print(f"Loaded {len(self.noisy_files)} segments.")

 def __len__(self):
 return len(self.noisy_files)

 def __getitem__(self, idx):
 noisy_file = self.noisy_files[idx]
 noisy_spectrogram = torch.load(noisy_file)
 print(f"Fetching segment: {noisy_file}")
 return noisy_spectrogram, noisy_file

In [3]:
def custom_collate(batch):
 segments, basenames = zip(*batch)
 return torch.cat(segments, dim=0), basenames

In [4]:
def reassemble_segments(segments, segment_files, original_lengths):
 segments_dict = {}
 for segment, file_path, original_length in zip(segments, segment_files, original_lengths):
 base_name = os.path.basename(file_path).split('_segment')[0]
 segment_number = int(os.path.basename(file_path).split('_segment')[1].split('.')[0])
 if base_name not in segments_dict:
 segments_dict[base_name] = []
 segments_dict[base_name].append((segment_number, segment, original_length))
 
 reassembled_audios = []
 for base_name, segs in segments_dict.items():
 segs = sorted(segs, key=lambda x: x[0]) # Sort by segment number
 audio_segments = []
 for i, seg in enumerate(segs):
 segment_length = seg[2]
 if i == len(segs) - 1: # Trim the last segment back to its original length
 audio_segment = seg[1][:segment_length]
 else:
 audio_segment = seg[1]
 audio_segments.append(audio_segment)
 reassembled_audio = torch.cat(audio_segments, dim=0)
 reassembled_audios.append((base_name, reassembled_audio))
 print(f"Reassembled audio length for {base_name}: {reassembled_audio.shape[0]}") # Debug statement
 return reassembled_audios

In [5]:
def save_audio(audio, path, sample_rate):
 sf.write(path, audio.cpu().numpy(), sample_rate)


In [6]:
def process_segments(model, segments, device, fixed_length):
 model.eval()
 cleaned_segments = []
 segment_files = []
 original_lengths = []

 with torch.no_grad():
 for segment, file_path in segments:
 current_length = segment.size(1)
 original_lengths.append(current_length)
 print(f"Original length of segment: {current_length}")

 segment = segment.unsqueeze(0).to(device) # Add batch dimension and move to device
 print(f"Segment shape before model: {segment.shape}")

 cleaned_spectrogram = model(segment).squeeze(0)
 print(f"Cleaned spectrogram shape: {cleaned_spectrogram.shape}")

 # Reconstruct complex spectrogram
 real_part = cleaned_spectrogram[..., 0]
 imag_part = cleaned_spectrogram[..., 1]
 complex_spectrogram = torch.view_as_complex(torch.stack((real_part, imag_part), dim=-1))

 # Ensure window is on the same device as the spectrogram
 window = torch.hamming_window(1024).to(device)

 # Convert the spectrogram back to waveform using istft
 cleaned_audio = torch.istft(
 complex_spectrogram, n_fft=1024, hop_length=512, win_length=1024, window=window, length=current_length * 512
 )

 print(f"Cleaned audio length: {cleaned_audio.shape[0]}")

 cleaned_segments.append(cleaned_audio) 
 segment_files.append(file_path) # Ensure file_path is correctly stored as a string

 return cleaned_segments, segment_files, original_lengths

In [7]:
def main():
 noisy_test_dir = "/home/siddharth/Sid/ASR/ANC/Babble_noise_speech_test_trim"
 preprocessed_test_dir = "/home/siddharth/Sid/ASR/ANC/Pre_processed_test_data"
 save_clean_dir = "/home/siddharth/Sid/ASR/ANC/cleaned_speeches"
 model_path = "/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs16_lr0.001_ep1500_og.pth"

 if not os.path.exists(save_clean_dir):
 os.makedirs(save_clean_dir)

 preprocessor = Preprocessing(sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024)
 if not os.path.exists(preprocessed_test_dir):
 os.makedirs(preprocessed_test_dir)

 # Create a temporary directory for padded files
 if os.path.exists(TEMP_DIR):
 shutil.rmtree(TEMP_DIR)
 os.makedirs(TEMP_DIR)

 # Load and pad noisy signals before processing
 noisy_files = [os.path.join(noisy_test_dir, f) for f in os.listdir(noisy_test_dir) if f.endswith('.wav')]
 for file in noisy_files:
 signal, sr = sf.read(file)
 length = len(signal)
 if length < MEDIAN_LENGTH:
 padded_signal = np.pad(signal, (0, MEDIAN_LENGTH - length), 'constant')
 elif length > MEDIAN_LENGTH:
 padded_signal = signal[:MEDIAN_LENGTH]
 else:
 padded_signal = signal
 temp_file_path = os.path.join(TEMP_DIR, os.path.basename(file))
 sf.write(temp_file_path, padded_signal, sr) # Save the padded signal to the temporary directory

 # Create the dataset using the padded signals in the temporary directory
 preprocessor.create_dataset(TEMP_DIR, preprocessed_test_dir)
 fixed_length = preprocessor.fixed_length

 test_dataset = NoisySpeechTestDataset(os.path.join(preprocessed_test_dir, 'noisy'), preprocessor)
 test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = CRN().to(device)

 # Load the DDP-trained model and remove the "module." prefix
 state_dict = torch.load(model_path, map_location=device)
 new_state_dict = {}
 for k, v in state_dict.items():
 if k.startswith("module."):
 new_state_dict[k[7:]] = v # Remove "module." prefix
 else:
 new_state_dict[k] = v
 model.load_state_dict(new_state_dict)

 all_segments = defaultdict(list)

 for noisy_spectrogram, noisy_path in test_loader:
 noisy_spectrogram = noisy_spectrogram.squeeze(0) # Remove batch dimension
 print(f"Shape of noisy_spectrogram: {noisy_spectrogram.shape}")
 # Here noisy_path should be accessed as a string, not a tuple
 noisy_path_str = noisy_path[0] # Extract the string from the tuple

 cleaned_segments, segment_files, original_lengths = process_segments(model, [(noisy_spectrogram, noisy_path_str)], device, fixed_length)
 for cleaned_segment, segment_file, original_length in zip(cleaned_segments, segment_files, original_lengths):
 base_name = os.path.basename(segment_file).split('_segment')[0]
 all_segments[base_name].append((cleaned_segment, segment_file, original_length))

 for base_name, segments in all_segments.items():
 reassembled_audios = reassemble_segments([seg[0] for seg in segments], [seg[1] for seg in segments], [seg[2] for seg in segments])
 for base_name, audio in reassembled_audios:
 save_path = os.path.join(save_clean_dir, f"{base_name}_cleaned.wav")
 save_audio(audio, save_path, preprocessor.loader.sample_rate)
 print(f"Saved cleaned audio to {save_path}")

In [8]:
if __name__ == "__main__":
 main()

Determined fixed length: 86
Loaded 5285 segments.


 state_dict = torch.load(model_path, map_location=device)
 noisy_spectrogram = torch.load(noisy_file)


Fetching segment: /home/siddharth/Sid/ASR/ANC/Pre_processed_test_data/noisy/noisy_Babble_clnsp0_segment0.pt
Shape of noisy_spectrogram: torch.Size([513, 86, 2])
Original length of segment: 86
Segment shape before model: torch.Size([1, 513, 86, 2])
Cleaned spectrogram shape: torch.Size([513, 96, 2])
Cleaned audio length: 44032
Fetching segment: /home/siddharth/Sid/ASR/ANC/Pre_processed_test_data/noisy/noisy_Babble_clnsp0_segment1.pt
Shape of noisy_spectrogram: torch.Size([513, 86, 2])
Original length of segment: 86
Segment shape before model: torch.Size([1, 513, 86, 2])
Cleaned spectrogram shape: torch.Size([513, 96, 2])
Cleaned audio length: 44032
Fetching segment: /home/siddharth/Sid/ASR/ANC/Pre_processed_test_data/noisy/noisy_Babble_clnsp0_segment2.pt
Shape of noisy_spectrogram: torch.Size([513, 86, 2])
Original length of segment: 86
Segment shape before model: torch.Size([1, 513, 86, 2])
Cleaned spectrogram shape: torch.Size([513, 96, 2])
Cleaned audio length: 44032
Fetching segment