# ASR / train_trial.py
# Uploaded by SIDD2201 ("Upload 363 files", commit f2688f7, verified).
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import os
import matplotlib.pyplot as plt
from worker_module import main_worker
def main(
    *,
    noisy_dir="/home/siddharth/Sid/ASR/ANC/Babble_noise_speech_train_trim_og",
    clean_dir="/home/siddharth/Sid/ASR/ANC/clean_speech_train_trim_og",
    save_dir="/home/siddharth/Sid/ASR/ANC/preprocessed_trimmed_data_og",
    num_epochs=1500,
    learning_rate=0.01,
    batch_size=16,
    checkpoint_path=None,
):
    """Spawn one ``main_worker`` process per visible CUDA device and wait for them.

    Each spawned process is invoked by ``torch.multiprocessing.spawn`` as::

        main_worker(rank, world_size, noisy_dir, clean_dir, save_dir,
                    num_epochs, learning_rate, batch_size, checkpoint_path)

    where ``rank`` is the process index (0 .. world_size - 1) supplied by
    ``spawn`` and ``world_size`` is the number of visible CUDA devices.
    ``join=True`` makes this call block until every worker has exited.

    All parameters are keyword-only and default to the values the script was
    originally hard-coded with, so a bare ``main()`` behaves as before.

    Args:
        noisy_dir: Directory of noisy speech; forwarded unchanged to
            ``main_worker`` (presumably the training inputs — confirm there).
        clean_dir: Directory of clean speech; forwarded unchanged to
            ``main_worker`` (presumably the training targets — confirm there).
        save_dir: Directory forwarded unchanged to ``main_worker``; by its
            default name it appears to hold preprocessed data — TODO confirm.
        num_epochs: Number of training epochs, forwarded to ``main_worker``.
        learning_rate: Learning rate, forwarded to ``main_worker``.
        batch_size: Batch size, forwarded to ``main_worker`` (whether it is
            per process or global is decided there — verify).
        checkpoint_path: Path to a saved model checkpoint, or ``None``;
            forwarded to ``main_worker``, which decides how it is used.

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # With nprocs=0, mp.spawn starts no processes and returns at once, so
        # the run would appear to finish successfully without training anything.
        raise RuntimeError(
            "No CUDA devices found: this script starts one worker per GPU "
            "and needs at least one visible CUDA device."
        )

    # To resume from a previous run, pass e.g.
    # checkpoint_path="/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs16_lr0.001_ep1500_og.pth"
    mp.spawn(
        main_worker,
        args=(
            world_size,
            noisy_dir,
            clean_dir,
            save_dir,
            num_epochs,
            learning_rate,
            batch_size,
            checkpoint_path,
        ),
        nprocs=world_size,
        join=True,
    )
# Script entry point. The guard is also required by mp.spawn's "spawn" start
# method: child processes re-import this module and must not launch training.
if __name__ == "__main__":
    main()