# ASR/label.py
# Author: SIDD2201 — from the "Upload 363 files" commit (f2688f7, verified).
import torch
import numpy as np
import os
from torch.utils.data import DataLoader
from test_trim import NoisySpeechTestDataset
from Deep_ANC_model_trim import CRN
# Hard-coded locations for this labeling run; edit them to match the machine it runs on.
# Root of the pre-processed test set; label_preprocessed_dataset reads its 'noisy' subdirectory.
preprocessed_test_dir = "/home/siddharth/Sid/ASR/ANC/Pre_processed_test_data"
# Directory holding the pre-trained checkpoints listed in model_filenames.
models_path = "/home/siddharth/Sid/ASR/ANC/models"
# Destination of the label array (relative to the current working directory).
labels_output_path = "labels.npy"
# Checkpoint names model_0.pth ... model_14.pth; list position is the model's index.
model_filenames = ["model_%d.pth" % index for index in range(15)]
# Every checkpoint is loaded once at import time, in model_filenames order, so a
# model's position in `models` is the index that label_preprocessed_dataset records.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(filename):
    """Build a CRN on `device`, restore its weights from models_path/filename, and set eval mode.

    The checkpoints come from DDP training (per the original loading comment), so
    keys may carry a leading "module." that the bare CRN does not expect; that
    prefix is stripped before load_state_dict.
    """
    network = CRN().to(device)
    checkpoint = torch.load(os.path.join(models_path, filename), map_location=device)
    prefix = "module."
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    network.load_state_dict(cleaned)
    network.eval()
    return network


models = [_load_model(name) for name in model_filenames]
# Function to calculate SNR improvement
def calculate_snr(noisy, denoised):
signal_power = np.mean(denoised ** 2)
noise_power = np.mean((noisy - denoised) ** 2)
snr = 10 * np.log10(signal_power / noise_power)
return snr
def _best_model_index(noisy_spectrogram, models):
    """Return the index in `models` of the highest calculate_snr score for one input, or -1.

    Ties keep the earliest model (strict `>`). -1 is returned only when no score
    exceeds -inf, i.e. every score was nan or -inf.
    """
    # Loop-invariant work, done once per input rather than once per model.
    noisy_np = noisy_spectrogram.cpu().numpy()
    batch = noisy_spectrogram.unsqueeze(0).to(device)  # re-add the batch dimension
    best_snr = -np.inf
    best_model_idx = -1
    with torch.no_grad():
        for idx, model in enumerate(models):
            denoised = model(batch).squeeze(0)  # drop the batch dimension again
            snr = calculate_snr(noisy_np, denoised.cpu().numpy())
            if snr > best_snr:
                best_snr, best_model_idx = snr, idx
    return best_model_idx


def label_preprocessed_dataset(preprocessed_test_dir, models):
    """Label every file in `<preprocessed_test_dir>/noisy` with the index of its best model.

    Each item is passed through every model, and the label is the index of the model
    with the highest calculate_snr(noisy, model_output) score.

    NOTE(review): calculate_snr treats (noisy - output) as noise. A model whose output
    equals its input therefore scores +inf, so this label favors models that alter
    the input least. Confirm this is the intended selection criterion.

    Args:
        preprocessed_test_dir: directory whose 'noisy' subdirectory is handed to
            NoisySpeechTestDataset.
        models: sequence of models; a label is a position in this sequence.

    Returns:
        int64 array with one label per dataset item, in loader order. An entry is -1
        when no model scored above -inf. Empty datasets also yield an int64 array.
    """
    test_dataset = NoisySpeechTestDataset(os.path.join(preprocessed_test_dir, 'noisy'))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    labels = [
        # squeeze(0) removes the loader's batch dimension; the path is not needed.
        _best_model_index(noisy_spectrogram.squeeze(0), models)
        for noisy_spectrogram, _noisy_path in test_loader
    ]
    # Pin the dtype: np.array([]) would otherwise be float64 for an empty dataset.
    return np.array(labels, dtype=np.int64)
def main():
    """Script entry point: compute labels for the pre-processed test set and save them.

    Reads the module-level `preprocessed_test_dir` and `models`, and writes the
    result with np.save to `labels_output_path`.
    """
    np.save(
        labels_output_path,
        label_preprocessed_dataset(preprocessed_test_dir, models),
    )
    print(f"Labels saved to {labels_output_path}")


if __name__ == "__main__":
    main()