File size: 1,060 Bytes
f2688f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import os
import matplotlib.pyplot as plt
from worker_module import main_worker
def main(
    noisy_dir="/home/siddharth/Sid/ASR/ANC/Babble_noise_speech_train_trim_og",
    clean_dir="/home/siddharth/Sid/ASR/ANC/clean_speech_train_trim_og",
    save_dir="/home/siddharth/Sid/ASR/ANC/preprocessed_trimmed_data_og",
    num_epochs=1500,
    learning_rate=0.01,
    batch_size=16,
    checkpoint_path=None,
    world_size=None,
):
    """Spawn one ``main_worker`` process per GPU and wait for all of them.

    Every argument defaults to the value this script previously hard-coded,
    so ``main()`` with no arguments launches the same run as before.
    One difference: when no CUDA device is visible, it now raises instead of
    silently doing nothing.

    ``torch.multiprocessing.spawn`` prepends the process index, so each child
    runs ``main_worker(process_index, world_size, noisy_dir, clean_dir,
    save_dir, num_epochs, learning_rate, batch_size, checkpoint_path)``.
    What the worker does with these values lives in ``worker_module`` and is
    not visible here.

    Args:
        noisy_dir: Directory forwarded to the worker. By its name, this is
            presumably babble-noise-corrupted training speech
            (NOTE(review): confirm).
        clean_dir: Directory forwarded to the worker. Presumably this is the
            matching clean training speech (NOTE(review): confirm).
        save_dir: Directory forwarded to the worker. Presumably it holds
            preprocessed data (NOTE(review): confirm).
        num_epochs: Epoch count forwarded to the worker.
        learning_rate: Learning rate forwarded to the worker.
        batch_size: Batch size forwarded to the worker. Whether it is
            per-process or global is decided by ``main_worker``
            (TODO confirm).
        checkpoint_path: Path to a saved model checkpoint to forward to the
            worker, or ``None``. A previously used value was
            "/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs16_lr0.001_ep1500_og.pth".
        world_size: Number of worker processes to spawn. ``None`` (the
            default) uses ``torch.cuda.device_count()``.

    Raises:
        RuntimeError: If ``world_size`` resolves to fewer than 1 process.
    """
    if world_size is None:
        world_size = torch.cuda.device_count()

    # mp.spawn(nprocs=0) starts nothing and returns at once. Without this
    # check, a CPU-only machine would "finish" without ever training.
    if world_size < 1:
        raise RuntimeError(
            f"world_size must be >= 1 (got {world_size}); "
            "no CUDA devices are visible to launch distributed workers on."
        )

    # join=True blocks until every spawned worker has exited.
    mp.spawn(
        main_worker,
        args=(
            world_size,
            noisy_dir,
            clean_dir,
            save_dir,
            num_epochs,
            learning_rate,
            batch_size,
            checkpoint_path,
        ),
        nprocs=world_size,
        join=True,
    )
# Start training only when this file is executed as a script. Importing the
# module (including the re-import done by spawned child processes) must not
# launch another round of workers.
if __name__ == "__main__":
    main()