import os

import matplotlib.pyplot as plt
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from worker_module import main_worker


def main(
    *,
    noisy_dir="/home/siddharth/Sid/ASR/ANC/Babble_noise_speech_train_trim_og",
    clean_dir="/home/siddharth/Sid/ASR/ANC/clean_speech_train_trim_og",
    save_dir="/home/siddharth/Sid/ASR/ANC/preprocessed_trimmed_data_og",
    num_epochs=1500,
    learning_rate=0.01,
    batch_size=16,
    checkpoint_path=None,
    world_size=None,
):
    """Spawn one ``main_worker`` process per CUDA device and wait for them.

    All arguments are keyword-only and default to the values this script
    originally hard-coded, so ``main()`` with no arguments behaves as before.

    Args:
        noisy_dir: Directory forwarded to ``main_worker`` as the noisy input
            location.
        clean_dir: Directory forwarded to ``main_worker`` as the clean target
            location.
        save_dir: Directory forwarded to ``main_worker`` (presumably where
            preprocessed data is stored — TODO confirm in ``worker_module``).
        num_epochs: Number of epochs forwarded to ``main_worker``.
        learning_rate: Learning rate forwarded to ``main_worker``.
        batch_size: Batch size forwarded to ``main_worker``.
        checkpoint_path: Optional checkpoint file forwarded to
            ``main_worker``; ``None`` means no checkpoint. A previously used
            value was
            "/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs16_lr0.001_ep1500_og.pth".
        world_size: Number of processes to spawn. Defaults to
            ``torch.cuda.device_count()``.

    Raises:
        RuntimeError: If ``world_size`` resolves to fewer than one process.
            Without this check, ``mp.spawn`` with ``nprocs=0`` would start
            nothing and return as if training had finished.
    """
    if world_size is None:
        world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            f"No worker processes to launch (world_size={world_size}); "
            "no CUDA devices are visible to this process."
        )

    # mp.spawn calls main_worker(rank, *args) once per process, with rank
    # in [0, world_size), and join=True blocks until every worker exits.
    mp.spawn(
        main_worker,
        args=(
            world_size,
            noisy_dir,
            clean_dir,
            save_dir,
            num_epochs,
            learning_rate,
            batch_size,
            checkpoint_path,
        ),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    main()