|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.distributed as dist |
|
|
import torch.multiprocessing as mp |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
from Deep_ANC_model_trim import CRN |
|
|
import logging |
|
|
from Pre_processing import Preprocessing |
|
|
import random |
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
from ranger import Ranger |
|
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts |
|
|
from torch.optim.lr_scheduler import OneCycleLR |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide logging: INFO level on the root logger, plus a module-scoped
# logger used throughout this file (checkpoint loading, DDP setup, errors).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def custom_loss_function(output, target):
    """Mean-squared-error loss tolerant of mismatched tensor shapes.

    The model may emit a few more or fewer time frames than the target
    spectrogram, so when shapes differ both tensors are cropped to their
    shared minimum extent before the MSE is computed.

    Args:
        output: Predicted tensor (any rank; originally a 4-D spectrogram).
        target: Reference tensor of the same rank as `output`.

    Returns:
        Scalar tensor: mean of the squared element-wise difference over the
        cropped overlap region.
    """
    if output.size() != target.size():
        # Crop EVERY mismatched dimension to the shared minimum. The
        # original code sliced only dim 2 with a hard-coded 4-D pattern
        # ([:, :, :min_size, :]), which broke for other ranks and ignored
        # mismatches elsewhere; for 4-D inputs differing only in dim 2 this
        # produces the identical result.
        overlap = tuple(
            slice(0, min(o, t)) for o, t in zip(output.shape, target.shape)
        )
        output = output[overlap]
        target = target[overlap]
    return torch.mean((output - target) ** 2)
|
|
class NoisySpeechDataset(Dataset):
    """Paired noisy/clean spectrogram dataset backed by per-sample .pt files.

    Both directories must contain the same number of `.pt` files; pairing is
    by sorted filename order, so corresponding noisy and clean files must
    sort identically.

    Args:
        noisy_dir: Directory of noisy-spectrogram `.pt` files.
        clean_dir: Directory of clean-spectrogram `.pt` files.
        subset_size: Maximum number of pairs to keep (clamped to dataset size).
        shuffle: Shuffle pairs (jointly, keeping alignment) before taking the
            subset.
        seed: Optional shuffle seed. The original code used the unseeded
            global `random`, so under DDP each spawned process could shuffle
            (and therefore subset) differently; passing a common seed makes
            every rank see the same file list. Default `None` preserves the
            old behavior.
    """

    def __init__(self, noisy_dir, clean_dir, subset_size=50000, shuffle=True, seed=None):
        self.noisy_files = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.pt')])
        self.clean_files = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.pt')])
        assert len(self.noisy_files) == len(self.clean_files), "Mismatched noisy and clean datasets"

        if shuffle:
            # Shuffle noisy/clean as pairs so alignment is preserved.
            combined = list(zip(self.noisy_files, self.clean_files))
            rng = random.Random(seed) if seed is not None else random
            rng.shuffle(combined)
            self.noisy_files, self.clean_files = zip(*combined)

        # Clamp so slicing below never silently over-requests.
        subset_size = min(subset_size, len(self.noisy_files))
        self.noisy_files = self.noisy_files[:subset_size]
        self.clean_files = self.clean_files[:subset_size]

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        # weights_only=True restricts torch.load to tensor payloads (safe
        # against pickled code execution).
        noisy_spectrogram = torch.load(self.noisy_files[idx], weights_only=True)
        clean_spectrogram = torch.load(self.clean_files[idx], weights_only=True)
        return noisy_spectrogram, clean_spectrogram
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def snr_improvement(noisy, clean, enhanced):
    """Return the SNR gain (in dB) of `enhanced` over `noisy` w.r.t. `clean`.

    All three tensors are 4-D spectrograms; their time axes (dim 2) are
    cropped to the shared minimum before comparison. Returns a 0.0 tensor
    when either residual has zero power (nothing meaningful to compare).
    """
    # Align the time axis of all three spectrograms.
    frames = min(noisy.size(2), clean.size(2), enhanced.size(2))
    noisy = noisy[:, :, :frames, :]
    clean = clean[:, :, :frames, :]
    enhanced = enhanced[:, :, :frames, :]

    # Residual noise before and after enhancement.
    residual_before = noisy - clean
    residual_after = enhanced - clean

    power_before = torch.mean(residual_before ** 2)
    power_after = torch.mean(residual_after ** 2)

    # Degenerate cases: identical signals on either side yield no gain.
    if power_before == 0 or power_after == 0:
        return torch.tensor(0.0)

    signal_power = torch.mean(clean ** 2)
    gain = (signal_power / power_after) / (signal_power / power_before)

    return 10 * torch.log10(gain)
|
|
|
|
|
def plot_metrics(train_metrics, val_metrics, metric_name):
    """Plot per-epoch training vs. validation curves for `metric_name`.

    Training values are drawn as blue dots, validation values as a solid
    blue line; the window blocks until closed (`plt.show`).
    """
    x_axis = list(range(1, len(train_metrics) + 1))
    plt.plot(x_axis, train_metrics, 'bo', label=f'Training {metric_name}')
    plt.plot(x_axis, val_metrics, 'b', label=f'Validation {metric_name}')
    plt.title(f'Training and Validation {metric_name}')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    plt.show()
|
|
|
|
|
|
|
|
def train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path=None):
    """Run the per-process (one GPU) DDP training loop.

    Wraps `model` in DistributedDataParallel on GPU `rank`, optionally
    resumes from `checkpoint_path`, then alternates training and validation
    for `num_epochs` epochs, tracking per-epoch SNR improvement. Rank 0
    periodically saves the model to `save_path`, keeps the best-validation
    model at `best_save_path`, and plots the SNR curves at the end.

    Args:
        rank: This process's GPU index / distributed rank.
        world_size: Total number of distributed processes (unused here;
            kept for signature symmetry with the spawner).
        model: Un-wrapped CRN instance; moved to GPU and DDP-wrapped here.
        train_loader / val_loader: DataLoaders yielding (noisy, clean)
            spectrogram batches.
        num_epochs: Number of epochs to run starting from the resume epoch.
        learning_rate: NOTE(review): only used in the final status print —
            the optimizer below hard-codes lr=0.001; confirm which is intended.
        save_path: Periodic (every 50 epochs) checkpoint destination.
        best_save_path: Best-validation-loss model destination.
        checkpoint_path: Optional checkpoint to resume from.
    """
    try:
        # Anomaly detection makes every backward pass slower; presumably
        # left enabled for debugging NaN losses — consider disabling for
        # production runs.
        torch.autograd.set_detect_anomaly(True)

        # Pin this process to its GPU and wrap the model so gradients are
        # synchronized across ranks on backward().
        torch.cuda.set_device(rank)
        model = model.to(rank)
        model = DDP(model, device_ids=[rank])

        optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

        # LR is cut by 10x after 20 epochs without validation-loss improvement.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         factor=0.1, patience=20,
                                                         verbose=True)

        start_epoch = 0
        best_val_loss = float('inf')
        best_val_snr_improvement = float('-inf')

        # ---- optional resume from checkpoint ----
        if checkpoint_path and os.path.exists(checkpoint_path):
            try:
                checkpoint = torch.load(checkpoint_path, map_location=torch.device(f'cuda:{rank}'))
                print(f"Checkpoint keys: {checkpoint.keys()}")

                # NOTE(review): the raw checkpoint dict is passed straight to
                # load_state_dict, yet optimizer/scheduler/epoch entries are
                # looked up as sub-keys below — those two layouts (bare
                # state_dict vs. composite dict) are mutually exclusive.
                # Confirm which format the checkpoints actually use.
                model.load_state_dict(checkpoint)
                logger.info(f"Model state loaded directly from checkpoint.")

                if 'optimizer_state_dict' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                if 'scheduler_state_dict' in checkpoint:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                # Resume one epoch after the stored one; fall back to fresh
                # values when keys are absent.
                start_epoch = checkpoint.get('epoch', 0) + 1
                best_val_loss = checkpoint.get('best_val_loss', float('inf'))
                best_val_snr_improvement = checkpoint.get('best_val_snr_improvement', float('-inf'))
                logger.info(f"Resuming training from epoch {start_epoch}")

            except Exception as e:
                logger.error(f"Error loading checkpoint: {e}")
                raise e

        model.train()
        training_snr_improvements = []   # per-epoch training SNR gain (dB)
        validation_snr_improvements = [] # per-epoch validation SNR gain (dB)

        for epoch in range(start_epoch, start_epoch + num_epochs):
            # NOTE(review): running_loss and train_snr_improvement are
            # accumulated/initialized but never reported — dead state.
            running_loss = 0.0
            train_snr_improvement = 0.0
            total_samples = 0
            batch_snr_improvements = []

            # ---- training pass ----
            for i, (noisy_spectrogram, clean_spectrogram) in enumerate(train_loader):
                noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True)
                clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True)

                optimizer.zero_grad()

                # Mixed-precision forward pass.
                # NOTE(review): autocast is used without a GradScaler, so
                # reduced-precision gradients are not loss-scaled — confirm
                # this is intentional (risk of underflow in fp16).
                with torch.amp.autocast(device_type='cuda'):
                    output = model(noisy_spectrogram)
                    loss = custom_loss_function(output, clean_spectrogram)

                # Skip the optimizer step entirely on a non-finite loss.
                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    print(f"NaN or Inf detected in loss at iteration {i}, epoch {epoch}")
                    continue

                loss.backward()

                # Clip the global gradient norm to 1.0 to stabilize training.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

                running_loss += loss.item()

                # Per-sample SNR improvement, averaged over the batch
                # (samples are sliced to keep the 4-D shape snr_improvement
                # expects).
                batch_snr_improvement = 0.0
                for j in range(noisy_spectrogram.size(0)):
                    single_snr_improvement = snr_improvement(
                        noisy_spectrogram[j:j+1], clean_spectrogram[j:j+1], output[j:j+1]
                    ).item()
                    batch_snr_improvement += single_snr_improvement

                batch_snr_improvement /= noisy_spectrogram.size(0)
                batch_snr_improvements.append(batch_snr_improvement)
                total_samples += noisy_spectrogram.size(0)

            # Mean of per-batch means; raises ZeroDivisionError if every
            # batch was skipped for a non-finite loss.
            training_snr_improvement_avg = sum(batch_snr_improvements) / len(batch_snr_improvements)
            training_snr_improvements.append(training_snr_improvement_avg)

            print(f"Epoch {epoch+1}, Training SNR Improvement: {training_snr_improvement_avg}")
            print(f"Epoch {epoch+1}, Total Samples Processed: {total_samples}")

            # ---- validation pass ----
            model.eval()
            val_loss = 0.0
            val_snr_improvement = 0.0
            with torch.no_grad():
                for noisy_spectrogram, clean_spectrogram in val_loader:
                    noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True)
                    clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True)
                    with torch.amp.autocast(device_type='cuda'):
                        output = model(noisy_spectrogram)
                        loss = custom_loss_function(output, clean_spectrogram)

                    val_loss += loss.item()
                    val_snr_improvement += snr_improvement(noisy_spectrogram, clean_spectrogram, output).item()

            # Averages are per-batch, not per-sample.
            val_loss /= len(val_loader)
            val_snr_improvement /= len(val_loader)
            validation_snr_improvements.append(val_snr_improvement)

            print(f"Epoch {epoch+1}, Validation Loss: {val_loss}, Validation SNR Improvement: {val_snr_improvement}")
            model.train()

            # ---- checkpointing (rank 0 only, using this rank's local
            # validation metrics — NOTE(review): not all-reduced across
            # ranks; confirm that is acceptable) ----
            if rank == 0:
                if (epoch + 1) % 50 == 0:
                    torch.save(model.state_dict(), save_path)
                    print(f"Model saved at epoch {epoch+1}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), best_save_path)
                    print(f"Best model saved at epoch {epoch+1} with validation loss {best_val_loss}")

                if val_snr_improvement > best_val_snr_improvement:
                    best_val_snr_improvement = val_snr_improvement

            # Plateau scheduler steps on each rank's own validation loss.
            scheduler.step(val_loss)

        if rank == 0:
            print(f"Training complete for batch size {train_loader.batch_size}, learning rate {learning_rate}, epochs {num_epochs}")
            print(f"Best Validation Loss: {best_val_loss}, Best Validation SNR Improvement: {best_val_snr_improvement}")
            plot_metrics(training_snr_improvements, validation_snr_improvements, 'SNR Improvement')

    except Exception as e:
        # Errors are printed (not re-raised) so every rank reaches cleanup.
        print(f"Rank {rank} encountered an error: {e}")
    finally:
        # Drain outstanding CUDA work before tearing down the process group.
        torch.cuda.synchronize()
        cleanup()
|
|
|
|
|
def setup(rank, world_size):
    """Initialize the NCCL process group and pin this process to GPU `rank`.

    Args:
        rank: This process's distributed rank (also used as the GPU index).
        world_size: Total number of participating processes.

    NOTE(review): MASTER_ADDR/MASTER_PORT are not set here — presumably the
    launcher exports them; confirm, or init_process_group will fail.
    """
    logger.info(f"Setting up distributed training on rank {rank}")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
|
|
|
|
def cleanup():
    """Tear down the distributed process group, swallowing any failure.

    Safe to call even when no process group was ever initialized (the
    resulting error is printed rather than raised).
    """
    try:
        dist.destroy_process_group()
    except Exception as err:
        print(f"Error during cleanup: {err}")
|
|
|
|
|
def main_worker(rank, world_size, noisy_dir, clean_dir, save_dir, num_epochs, learning_rate, batch_size, checkpoint_path):
    """Per-process entry point: build data pipeline and train on one GPU.

    Initializes the process group, constructs the paired dataset and the
    DDP-aware loaders, then hands off to train_model(). Always tears the
    process group down in `finally`.

    Args:
        rank: This process's distributed rank / GPU index.
        world_size: Total number of spawned processes.
        noisy_dir / clean_dir: NOTE(review): these parameters are ignored —
            the dataset is built from <save_dir>/noisy and <save_dir>/clean
            below; confirm which directories are intended.
        save_dir: Root directory containing 'noisy' and 'clean' subdirs.
        num_epochs / learning_rate / batch_size: Training hyperparameters.
        checkpoint_path: Optional checkpoint to resume from.
    """
    try:
        setup(rank, world_size)

        dataset = NoisySpeechDataset(os.path.join(save_dir, 'noisy'), os.path.join(save_dir, 'clean'), subset_size=50000)

        # 80/20 train/validation split.
        # NOTE(review): random_split is unseeded, so each rank may compute a
        # different split — verify whether a shared generator is needed for
        # the DistributedSamplers below to partition consistently.
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)

        model = CRN()

        # Checkpoint filenames encode the hyperparameters of this run.
        save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_trim_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth"
        best_save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth"

        train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path)

    except Exception as e:
        logger.error(f"An error occurred on rank {rank}: {e}")
    finally:
        # train_model also calls cleanup(); destroy_process_group failures
        # there are caught and printed, so the double call is tolerated.
        cleanup()
|
|
|