# Distributed (DDP) training script for the CRN model from Deep_ANC_model_trim.
# Third-party, project-local and stdlib imports are kept in their original
# order; several are unused below but may be referenced elsewhere in the file.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import os
import matplotlib.pyplot as plt
from Deep_ANC_model_trim import CRN
import logging
from Pre_processing import Preprocessing
import random
from torch.optim.lr_scheduler import CosineAnnealingLR
from ranger import Ranger
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import OneCycleLR

# Route INFO-and-above records through the root handler; every function in
# this module logs through the module-level `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def custom_loss_function(output, target):
    """Return the mean-squared error between ``output`` and ``target``.

    When the two tensors differ in size, both are cropped along dim 2 to the
    shorter of the two lengths first (they are indexed as 4-D tensors,
    ``[:, :, :n, :]``), so a model that emits a few extra or missing frames can
    still be scored.
    """
    if output.size() != target.size():
        common_len = min(output.size(2), target.size(2))
        output = output[:, :, :common_len, :]
        target = target[:, :, :common_len, :]
    return torch.mean((output - target) ** 2)


class NoisySpeechDataset(Dataset):
    """Paired (noisy, clean) tensors stored as ``.pt`` files in two directories.

    The ``.pt`` filenames of each directory are sorted and paired by position,
    so the i-th noisy file is matched with the i-th clean file.
    NOTE(review): this relies on both directories using matching names.

    Args:
        noisy_dir: directory holding the noisy ``.pt`` files.
        clean_dir: directory holding the clean ``.pt`` files.
        subset_size: keep at most this many pairs; ``None`` keeps all of them.
        shuffle: shuffle the pairs (jointly) before taking the subset.
        seed: seed for the shuffle. ``None`` (the default) draws from the
            global ``random`` module state, as before; pass an int to get the
            same order in every process, e.g. on every DDP rank.

    Raises:
        ValueError: if the directories hold different numbers of ``.pt`` files.
    """

    def __init__(self, noisy_dir, clean_dir, subset_size=50000, shuffle=True, seed=None):
        self.noisy_files = self._list_pt_files(noisy_dir)
        self.clean_files = self._list_pt_files(clean_dir)
        # An explicit raise (not `assert`) so the check survives `python -O`.
        if len(self.noisy_files) != len(self.clean_files):
            raise ValueError("Mismatched noisy and clean datasets")

        # Shuffle pairs jointly so each noisy file stays with its clean partner.
        # Guarded on non-empty input: `zip(*[])` cannot be unpacked into two lists.
        if shuffle and self.noisy_files:
            pairs = list(zip(self.noisy_files, self.clean_files))
            rng = random if seed is None else random.Random(seed)
            rng.shuffle(pairs)
            self.noisy_files = [noisy for noisy, _ in pairs]
            self.clean_files = [clean for _, clean in pairs]

        # Limit the subset size if provided.
        if subset_size is not None:
            self.noisy_files = self.noisy_files[:subset_size]
            self.clean_files = self.clean_files[:subset_size]

    @staticmethod
    def _list_pt_files(directory):
        """Return the sorted full paths of the ``.pt`` files in ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.pt')
        )

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        """Load and return the ``(noisy, clean)`` tensor pair at ``idx``."""
        noisy_spectrogram = torch.load(self.noisy_files[idx], weights_only=True)
        clean_spectrogram = torch.load(self.clean_files[idx], weights_only=True)
        return noisy_spectrogram, clean_spectrogram
torch.mean(clean ** 2) / torch.mean(noise_est ** 2) # return 10 * torch.log10(snr_after / snr_before) def snr_improvement(noisy, clean, enhanced): min_size = min(noisy.size(2), clean.size(2), enhanced.size(2)) noisy = noisy[:, :, :min_size, :] clean = clean[:, :, :min_size, :] enhanced = enhanced[:, :, :min_size, :] noise = noisy - clean noise_est = enhanced - clean # Ensure the denominator isn't zero to avoid NaN values noise_power = torch.mean(noise ** 2) noise_est_power = torch.mean(noise_est ** 2) if noise_power == 0 or noise_est_power == 0: return torch.tensor(0.0) # Avoid division by zero and return 0 SNR improvement snr_before = torch.mean(clean ** 2) / noise_power snr_after = torch.mean(clean ** 2) / noise_est_power return 10 * torch.log10(snr_after / snr_before) def plot_metrics(train_metrics, val_metrics, metric_name): epochs = range(1, len(train_metrics) + 1) plt.plot(epochs, train_metrics, 'bo', label=f'Training {metric_name}') plt.plot(epochs, val_metrics, 'b', label=f'Validation {metric_name}') plt.title(f'Training and Validation {metric_name}') plt.xlabel('Epochs') plt.ylabel(metric_name) plt.legend() plt.show() def train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path=None): try: # Enable anomaly detection torch.autograd.set_detect_anomaly(True) # Set the device for the current process torch.cuda.set_device(rank) model = model.to(rank) model = DDP(model, device_ids=[rank]) # # Apply weight initialization # def init_weights(m): # if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): # nn.init.kaiming_uniform_(m.weight, nonlinearity='relu') # He initialization # if m.bias is not None: # nn.init.constant_(m.bias, 0) # elif isinstance(m, nn.LSTM): # for name, param in m.named_parameters(): # if 'weight' in name: # nn.init.xavier_uniform_(param) # Xavier initialization # elif 'bias' in name: # nn.init.constant_(param, 0) # # Initialize the weights of the model # 
model.apply(init_weights) # Initialize the Adam optimizer optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True) # Set up the learning rate scheduler # scheduler = OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=1250, epochs=1500) # Define the Cosine Annealing Warm Restarts scheduler scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, verbose=True) # Initialize variables for tracking progress start_epoch = 0 best_val_loss = float('inf') best_val_snr_improvement = float('-inf') # Load checkpoint if provided and if it exists if checkpoint_path and os.path.exists(checkpoint_path): try: checkpoint = torch.load(checkpoint_path, map_location=torch.device(f'cuda:{rank}')) print(f"Checkpoint keys: {checkpoint.keys()}") # Inspect the checkpoint keys # Directly load the state dictionary into the model model.load_state_dict(checkpoint) logger.info(f"Model state loaded directly from checkpoint.") # If available, load optimizer and scheduler states if 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) if 'scheduler_state_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_state_dict']) # Load additional tracking variables if available start_epoch = checkpoint.get('epoch', 0) + 1 best_val_loss = checkpoint.get('best_val_loss', float('inf')) best_val_snr_improvement = checkpoint.get('best_val_snr_improvement', float('-inf')) logger.info(f"Resuming training from epoch {start_epoch}") except Exception as e: logger.error(f"Error loading checkpoint: {e}") raise e # Re-raise the exception to prevent further issues # Training loop model.train() training_snr_improvements = [] validation_snr_improvements = [] for epoch in range(start_epoch, start_epoch + num_epochs): running_loss = 0.0 train_snr_improvement = 0.0 total_samples = 0 batch_snr_improvements = [] for i, (noisy_spectrogram, clean_spectrogram) in enumerate(train_loader): noisy_spectrogram = 
noisy_spectrogram.cuda(rank, non_blocking=True) clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True) optimizer.zero_grad() # Perform forward and backward pass with anomaly detection with torch.amp.autocast(device_type='cuda'): output = model(noisy_spectrogram) loss = custom_loss_function(output, clean_spectrogram) # Check for NaNs or Infs in loss if torch.isnan(loss).any() or torch.isinf(loss).any(): print(f"NaN or Inf detected in loss at iteration {i}, epoch {epoch}") continue loss.backward() # Apply gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() running_loss += loss.item() # Calculate SNR Improvement batch_snr_improvement = 0.0 for j in range(noisy_spectrogram.size(0)): single_snr_improvement = snr_improvement( noisy_spectrogram[j:j+1], clean_spectrogram[j:j+1], output[j:j+1] ).item() batch_snr_improvement += single_snr_improvement batch_snr_improvement /= noisy_spectrogram.size(0) batch_snr_improvements.append(batch_snr_improvement) total_samples += noisy_spectrogram.size(0) # Aggregate batch SNR improvements training_snr_improvement_avg = sum(batch_snr_improvements) / len(batch_snr_improvements) training_snr_improvements.append(training_snr_improvement_avg) print(f"Epoch {epoch+1}, Training SNR Improvement: {training_snr_improvement_avg}") print(f"Epoch {epoch+1}, Total Samples Processed: {total_samples}") # Validation phase model.eval() val_loss = 0.0 val_snr_improvement = 0.0 with torch.no_grad(): for noisy_spectrogram, clean_spectrogram in val_loader: noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True) clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True) with torch.amp.autocast(device_type='cuda'): output = model(noisy_spectrogram) loss = custom_loss_function(output, clean_spectrogram) val_loss += loss.item() val_snr_improvement += snr_improvement(noisy_spectrogram, clean_spectrogram, output).item() val_loss /= len(val_loader) val_snr_improvement /= 
len(val_loader) validation_snr_improvements.append(val_snr_improvement) print(f"Epoch {epoch+1}, Validation Loss: {val_loss}, Validation SNR Improvement: {val_snr_improvement}") model.train() # Save the model every 50 epochs if rank == 0: if (epoch + 1) % 50 == 0: torch.save(model.state_dict(), save_path) print(f"Model saved at epoch {epoch+1}") if val_loss < best_val_loss: best_val_loss = val_loss torch.save(model.state_dict(), best_save_path) print(f"Best model saved at epoch {epoch+1} with validation loss {best_val_loss}") if val_snr_improvement > best_val_snr_improvement: best_val_snr_improvement = val_snr_improvement # Step the learning rate scheduler scheduler.step(val_loss) if rank == 0: print(f"Training complete for batch size {train_loader.batch_size}, learning rate {learning_rate}, epochs {num_epochs}") print(f"Best Validation Loss: {best_val_loss}, Best Validation SNR Improvement: {best_val_snr_improvement}") plot_metrics(training_snr_improvements, validation_snr_improvements, 'SNR Improvement') except Exception as e: print(f"Rank {rank} encountered an error: {e}") finally: torch.cuda.synchronize() # Ensure all operations are complete before cleanup cleanup() def setup(rank, world_size): logger.info(f"Setting up distributed training on rank {rank}") dist.init_process_group("nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) # torch.distributed.barrier() def cleanup(): try: dist.destroy_process_group() except Exception as e: print(f"Error during cleanup: {e}") def main_worker(rank, world_size, noisy_dir, clean_dir, save_dir, num_epochs, learning_rate, batch_size, checkpoint_path): try: setup(rank, world_size) # Preprocessing the data # preprocessor = Preprocessing(sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024) # preprocessor.create_dataset(noisy_dir, clean_dir, save_dir) dataset = NoisySpeechDataset(os.path.join(save_dir, 'noisy'), os.path.join(save_dir, 'clean'), subset_size=50000) train_size = int(0.8 * len(dataset)) 
val_size = len(dataset) - train_size train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size]) train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False) train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2) val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler) model = CRN() save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_trim_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth" best_save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth" train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path) except Exception as e: logger.error(f"An error occurred on rank {rank}: {e}") finally: cleanup()