|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.distributed as dist |
|
|
import torch.multiprocessing as mp |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
from worker_module import main_worker |
|
|
def main():
    """Launch one training worker process per visible CUDA device.

    Gathers the dataset directories and training hyperparameters defined
    below and hands them to ``main_worker`` through
    ``torch.multiprocessing.spawn``.  ``spawn`` invokes the target as
    ``main_worker(process_index, *args)``, so each worker receives its own
    index followed by the values packed into ``worker_args``.  The call
    blocks until every worker has exited (``join=True``).

    Raises:
        RuntimeError: If ``torch.cuda.device_count()`` reports no device.
            Spawning with ``nprocs=0`` would start no workers at all, so the
            script would otherwise finish without training anything.
    """
    # Dataset locations.
    # NOTE(review): presumably noisy_dir/clean_dir hold paired input audio
    # and save_dir receives preprocessed data - confirm in worker_module.
    noisy_dir = "/home/siddharth/Sid/ASR/ANC/Babble_noise_speech_train_trim_og"
    clean_dir = "/home/siddharth/Sid/ASR/ANC/clean_speech_train_trim_og"
    save_dir = "/home/siddharth/Sid/ASR/ANC/preprocessed_trimmed_data_og"

    # Training hyperparameters forwarded unchanged to every worker.
    num_epochs = 1500
    learning_rate = 0.01
    batch_size = 16

    # One process per visible GPU.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            "No CUDA devices are visible; at least one GPU is required to "
            "spawn training workers."
        )

    # NOTE(review): None presumably means "start from scratch" rather than
    # resuming from a saved checkpoint - confirm in worker_module.
    checkpoint_path = None

    # Order matters: these are passed positionally after the process index.
    worker_args = (
        world_size,
        noisy_dir,
        clean_dir,
        save_dir,
        num_epochs,
        learning_rate,
        batch_size,
        checkpoint_path,
    )
    mp.spawn(main_worker, args=worker_args, nprocs=world_size, join=True)
|
|
if __name__ == "__main__":
    # Run only when executed as a script.  The guard also keeps child
    # processes started with the "spawn" method from re-launching training
    # when they re-import this module.
    main()
|
|
|