"""
SEGAN-style Speech Enhancement (Noise Removal) Training Script

Unlike the original waveform-domain SEGAN, this variant enhances
log-magnitude spectrograms:

- Generator: U-Net on log-magnitude spectrograms
- Discriminator: PatchGAN-style conditional (noisy + clean/enhanced)
- Loss: L1 (reconstruction) + adversarial (LSGAN)

Requirements:
    pip install torch torchaudio
"""

import os
import glob
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchaudio


# =========================================================
# Configuration
# =========================================================
class Config:
    # Paths to the paired noisy / clean wav folders and the checkpoint output.
    noisy_dir = r"E:\Minor-Project-For-Amity-Patna\Models\Audio Data\Noisy Data"
    clean_dir = r"E:\Minor-Project-For-Amity-Patna\Models\Audio Data\Noiseless Data"
    save_dir = r"E:\Minor-Project-For-Amity-Patna\Model SEGAN\checkpoints_seagan"

    # Audio loading
    sample_rate = 16000
    segment_seconds = 1.0
    mono = True

    # STFT parameters
    n_fft = 512
    hop_length = 128
    win_length = 512

    # Training
    batch_size = 8
    num_workers = 2
    num_epochs = 50
    lr_g = 2e-4
    lr_d = 2e-4
    beta1 = 0.5
    beta2 = 0.999

    lambda_l1 = 100.0
    device = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Config()


# =========================================================
# Data utilities
# =========================================================
|
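# The dataset below pairs files by identical file name in the two folders,
# e.g. (illustrative names only, assuming your data follows this layout):
#   Noisy Data/p234_001.wav   <->   Noiseless Data/p234_001.wav
# Files that exist in only one of the two folders are silently skipped.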
def list_wav_pairs(noisy_dir: str, clean_dir: str) -> List[Tuple[str, str]]:
    """Return (noisy_path, clean_path) pairs matched by file name."""
    noisy_files = sorted(glob.glob(os.path.join(noisy_dir, "*.wav")))
    pairs = []
    for nf in noisy_files:
        name = os.path.basename(nf)
        cf = os.path.join(clean_dir, name)
        if os.path.exists(cf):
            pairs.append((nf, cf))
    return pairs

class SEAGANDataset(Dataset):
    """Paired noisy/clean dataset yielding aligned fixed-length waveform segments."""

    def __init__(
        self,
        noisy_dir: str,
        clean_dir: str,
        sample_rate: int = 16000,
        segment_seconds: float = 1.0,
    ):
        self.sample_rate = sample_rate
        self.segment_samples = int(segment_seconds * sample_rate)

        self.pairs = list_wav_pairs(noisy_dir, clean_dir)
        if len(self.pairs) == 0:
            raise RuntimeError("No paired .wav files found! Check your folders & names.")

        # Cache one Resample transform per encountered source sample rate.
        self.resampler_cache = {}

    def __len__(self):
        return len(self.pairs)

    def _get_resampler(self, orig_sr: int):
        if orig_sr == self.sample_rate:
            return None
        if orig_sr not in self.resampler_cache:
            self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(
                orig_freq=orig_sr, new_freq=self.sample_rate
            )
        return self.resampler_cache[orig_sr]

    def _load_audio(self, path: str) -> torch.Tensor:
        # Load, downmix to mono, and resample to the target rate if needed.
        wav, sr = torchaudio.load(path)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        resampler = self._get_resampler(sr)
        if resampler is not None:
            wav = resampler(wav)
        return wav

    def _aligned_random_crop(self, noisy: torch.Tensor, clean: torch.Tensor):
        """
        Crop noisy and clean with the same start index for alignment.
        noisy, clean: (1, T)
        """
        T = min(noisy.shape[1], clean.shape[1])
        noisy = noisy[:, :T]
        clean = clean[:, :T]

        if T <= self.segment_samples:
            # Too short: zero-pad both signals to the segment length.
            pad = self.segment_samples - T
            noisy = torch.nn.functional.pad(noisy, (0, pad))
            clean = torch.nn.functional.pad(clean, (0, pad))
            return noisy, clean
        else:
            start = random.randint(0, T - self.segment_samples)
            end = start + self.segment_samples
            return noisy[:, start:end], clean[:, start:end]

    def __getitem__(self, idx: int):
        noisy_path, clean_path = self.pairs[idx]

        noisy = self._load_audio(noisy_path)
        clean = self._load_audio(clean_path)

        noisy, clean = self._aligned_random_crop(noisy, clean)

        return noisy, clean


# =========================================================
# STFT transform
# =========================================================
class STFTMagTransform(nn.Module):
    """
    Convert waveform -> log-magnitude spectrogram.
    """

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

        # Register the analysis window as a buffer so it follows .to(device).
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """
        wav: (B, 1, T)
        return: (B, 1, F, T_spec)
        """
        B, C, T = wav.shape

        wav = wav.view(B * C, T)
        spec = torch.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        mag = torch.abs(spec)
        log_mag = torch.log1p(mag)  # compress dynamic range: log(1 + |X|)
        log_mag = log_mag.view(B, C, log_mag.shape[1], log_mag.shape[2])
        return log_mag
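

# Rough shape check for the default Config (an illustrative calculation, not
# executed by the training code): a 1 s segment at 16 kHz has T = 16000
# samples; with n_fft=512, hop_length=128 and torch.stft's default center=True,
# the STFT yields F = n_fft // 2 + 1 = 257 frequency bins and
# 1 + T // hop_length = 126 frames, so STFTMagTransform returns a
# (B, 1, 257, 126) tensor.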


# =========================================================
# Size-matching helper
# =========================================================

def match_size(a: torch.Tensor, b: torch.Tensor):
    """
    Crop a and b to have the same (H, W). Keeps the top-left region.
    a, b: (..., H, W)
    returns: (a_crop, b_crop)
    """
    Ha, Wa = a.shape[-2], a.shape[-1]
    Hb, Wb = b.shape[-2], b.shape[-1]
    H = min(Ha, Hb)
    W = min(Wa, Wb)
    a_c = a[..., :H, :W]
    b_c = b[..., :H, :W]
    return a_c, b_c
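

# Example (hypothetical shapes): with a of shape (B, 1, 257, 126) and b of
# shape (B, 1, 256, 64), match_size returns both cropped to (B, 1, 256, 64).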


# =========================================================
# Generator (U-Net)
# =========================================================

class ConvBlock(nn.Module):
    """Single down- (strided conv) or up-sampling (transposed conv) block."""

    def __init__(self, in_ch, out_ch, down=True, use_bn=True):
        super().__init__()
        if down:
            layers = [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        else:
            layers = [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]

        if use_bn:
            # Insert BatchNorm between the (transposed) conv and the activation.
            layers.insert(1, nn.BatchNorm2d(out_ch))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class UNetGenerator(nn.Module):
    """
    U-Net operating on (B, 1, F, T) log-magnitude spectrograms.
    """

    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()

        # Encoder
        self.down1 = ConvBlock(in_ch, base_ch, down=True, use_bn=False)
        self.down2 = ConvBlock(base_ch, base_ch * 2)
        self.down3 = ConvBlock(base_ch * 2, base_ch * 4)
        self.down4 = ConvBlock(base_ch * 4, base_ch * 8)
        self.down5 = ConvBlock(base_ch * 8, base_ch * 8)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base_ch * 8, base_ch * 8, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

        # Decoder (each block after up1 receives the previous decoder output
        # concatenated with a cropped encoder skip, hence the doubled channels).
        self.up1 = ConvBlock(base_ch * 8, base_ch * 8, down=False)
        self.up2 = ConvBlock(base_ch * 8 * 2, base_ch * 8, down=False)
        self.up3 = ConvBlock(base_ch * 8 * 2, base_ch * 4, down=False)
        self.up4 = ConvBlock(base_ch * 4 * 2, base_ch * 2, down=False)
        self.up5 = ConvBlock(base_ch * 2 * 2, base_ch, down=False)

        self.final = nn.ConvTranspose2d(
            base_ch * 2, out_ch, kernel_size=4, stride=2, padding=1
        )
        # Log1p magnitudes are non-negative, so ReLU is a valid output activation.
        self.out_act = nn.ReLU()

    def _crop_to(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        Center-crop src to have the same H, W as tgt.
        src: (B, C, Hs, Ws)
        tgt: (B, C, Ht, Wt) (only Ht, Wt are used)
        """
        _, _, Hs, Ws = src.shape
        _, _, Ht, Wt = tgt.shape

        if Hs == Ht and Ws == Wt:
            return src

        start_h = max((Hs - Ht) // 2, 0)
        start_w = max((Ws - Wt) // 2, 0)
        end_h = start_h + Ht
        end_w = start_w + Wt

        return src[:, :, start_h:end_h, start_w:end_w]

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        bott = self.bottleneck(d5)

        # Decoder with skip connections (encoder maps are center-cropped to
        # the decoder resolution before concatenation).
        u1 = self.up1(bott)
        d5_c = self._crop_to(d5, u1)
        u1 = torch.cat([u1, d5_c], dim=1)

        u2 = self.up2(u1)
        d4_c = self._crop_to(d4, u2)
        u2 = torch.cat([u2, d4_c], dim=1)

        u3 = self.up3(u2)
        d3_c = self._crop_to(d3, u3)
        u3 = torch.cat([u3, d3_c], dim=1)

        u4 = self.up4(u3)
        d2_c = self._crop_to(d2, u4)
        u4 = torch.cat([u4, d2_c], dim=1)

        u5 = self.up5(u4)
        d1_c = self._crop_to(d1, u5)
        u5 = torch.cat([u5, d1_c], dim=1)

        out = self.final(u5)
        out = self.out_act(out)
        return out
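

# Spatial-size walkthrough for the default Config (an illustrative calculation):
# a (B, 1, 257, 126) input is downsampled to (128, 63) -> (64, 31) -> (32, 15)
# -> (16, 7) -> (8, 3), and the bottleneck to (4, 1); the transposed
# convolutions then double each dimension back up to a (B, 1, 256, 64) output.
# Because odd sizes lose a bin/frame at each strided conv, the output is
# slightly smaller than the input, which is why train() crops tensors with
# match_size() before the discriminator inputs and losses that involve the
# generator output.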


# =========================================================
# Discriminator (PatchGAN)
# =========================================================

class PatchDiscriminator(nn.Module):
    """
    Conditional discriminator: input = concat(noisy_spec, clean_or_fake_spec).
    """

    def __init__(self, in_ch=2, base_ch=64):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base_ch, base_ch * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(base_ch * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base_ch * 2, base_ch * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(base_ch * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base_ch * 4, base_ch * 8, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm2d(base_ch * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # No sigmoid here: LSGAN (MSE loss) is applied to raw patch scores.
            nn.Conv2d(base_ch * 8, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.model(x)
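

# For a (B, 2, 256, 64) input (the spectrogram size from the walkthrough above,
# given as an illustration), the discriminator outputs a (B, 1, 30, 6) map of
# patch scores; the LSGAN targets built in train() are simply all-ones /
# all-zeros maps of the same shape.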


# =========================================================
# Training
# =========================================================

def save_checkpoint(epoch, G, D, opt_g, opt_d, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "G_state": G.state_dict(),
            "D_state": D.state_dict(),
            "opt_g_state": opt_g.state_dict(),
            "opt_d_state": opt_d.state_dict(),
        },
        path,
    )
    print(f"Saved checkpoint: {path}")
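

# Minimal sketch of how a saved checkpoint could be restored (an assumption
# about intended use; resuming is not wired into train() itself):
#
#     ckpt = torch.load(path, map_location=cfg.device)
#     G.load_state_dict(ckpt["G_state"])
#     D.load_state_dict(ckpt["D_state"])
#     opt_g.load_state_dict(ckpt["opt_g_state"])
#     opt_d.load_state_dict(ckpt["opt_d_state"])
#     start_epoch = ckpt["epoch"] + 1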


def train():
    device = cfg.device
    print(f"Using device: {device}")

    dataset = SEAGANDataset(
        cfg.noisy_dir, cfg.clean_dir, cfg.sample_rate, cfg.segment_seconds
    )
    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        drop_last=True,
    )

    stft_transform = STFTMagTransform(
        cfg.n_fft, cfg.hop_length, cfg.win_length
    ).to(device)

    G = UNetGenerator(in_ch=1, out_ch=1).to(device)
    D = PatchDiscriminator(in_ch=2).to(device)

    # LSGAN adversarial loss (MSE against 0/1 targets) + L1 reconstruction loss.
    criterion_gan = nn.MSELoss()
    criterion_l1 = nn.L1Loss()

    opt_g = optim.Adam(G.parameters(), lr=cfg.lr_g, betas=(cfg.beta1, cfg.beta2))
    opt_d = optim.Adam(D.parameters(), lr=cfg.lr_d, betas=(cfg.beta1, cfg.beta2))

    for epoch in range(1, cfg.num_epochs + 1):
        G.train()
        D.train()

        running_g_loss = 0.0
        running_d_loss = 0.0

        for i, (noisy_wav, clean_wav) in enumerate(loader):
            noisy_wav = noisy_wav.to(device)
            clean_wav = clean_wav.to(device)

            # Waveforms -> log-magnitude spectrograms.
            noisy_spec = stft_transform(noisy_wav)
            clean_spec = stft_transform(clean_wav)

            # Same STFT settings, so the shapes already match; crop defensively.
            noisy_spec, clean_spec = match_size(noisy_spec, clean_spec)

            # ---------------------------------------------
            # Train discriminator
            # ---------------------------------------------
            opt_d.zero_grad()

            # Real pair: (noisy, clean)
            real_input = torch.cat([noisy_spec, clean_spec], dim=1)
            pred_real = D(real_input)
            target_real = torch.ones_like(pred_real)
            loss_d_real = criterion_gan(pred_real, target_real)

            # Fake pair: (noisy, G(noisy)); the generator is detached via no_grad.
            with torch.no_grad():
                fake_spec = G(noisy_spec)
            # The generator output is slightly smaller; crop noisy to match.
            noisy_for_fake_d, fake_spec_d = match_size(noisy_spec, fake_spec)
            fake_input = torch.cat([noisy_for_fake_d, fake_spec_d], dim=1)
            pred_fake = D(fake_input)
            target_fake = torch.zeros_like(pred_fake)
            loss_d_fake = criterion_gan(pred_fake, target_fake)

            loss_d = 0.5 * (loss_d_real + loss_d_fake)
            loss_d.backward()
            opt_d.step()

            # ---------------------------------------------
            # Train generator
            # ---------------------------------------------
            opt_g.zero_grad()

            fake_spec = G(noisy_spec)

            # Adversarial term: fool D into predicting "real" for (noisy, fake).
            noisy_for_fake_g, fake_spec_g = match_size(noisy_spec, fake_spec)
            fake_input_g = torch.cat([noisy_for_fake_g, fake_spec_g], dim=1)
            pred_fake_for_g = D(fake_input_g)
            target_real_for_g = torch.ones_like(pred_fake_for_g)
            loss_g_gan = criterion_gan(pred_fake_for_g, target_real_for_g)

            # Reconstruction term: L1 between enhanced and clean spectrograms.
            fake_l1, clean_l1 = match_size(fake_spec, clean_spec)
            loss_g_l1 = criterion_l1(fake_l1, clean_l1) * cfg.lambda_l1

            loss_g = loss_g_gan + loss_g_l1
            loss_g.backward()
            opt_g.step()

            running_d_loss += loss_d.item()
            running_g_loss += loss_g.item()

            if (i + 1) % 20 == 0:
                print(
                    f"Epoch [{epoch}/{cfg.num_epochs}] "
                    f"Step [{i+1}/{len(loader)}] "
                    f"D Loss: {loss_d.item():.4f} "
                    f"G Loss: {loss_g.item():.4f} "
                    f"(GAN: {loss_g_gan.item():.4f}, L1: {loss_g_l1.item():.4f})"
                )

        avg_d = running_d_loss / len(loader)
        avg_g = running_g_loss / len(loader)
        print(
            f"==> Epoch {epoch} finished | "
            f"Avg D Loss: {avg_d:.4f} | Avg G Loss: {avg_g:.4f}"
        )

        # Periodic checkpoint every 5 epochs.
        if epoch % 5 == 0:
            ckpt_path = os.path.join(cfg.save_dir, f"seagan_epoch_{epoch}.pt")
            save_checkpoint(epoch, G, D, opt_g, opt_d, ckpt_path)

    # Final checkpoint.
    ckpt_path = os.path.join(cfg.save_dir, "seagan_final.pt")
    save_checkpoint(cfg.num_epochs, G, D, opt_g, opt_d, ckpt_path)
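

# The function below is a minimal inference sketch added as an illustration;
# it is not part of training and is an assumption about how a trained
# generator would be applied. The enhanced log-magnitude is inverted with
# expm1 and recombined with the *noisy* phase (phase is not modeled by this
# script), the magnitude is zero-padded back to n_fft // 2 + 1 bins because
# the U-Net output is slightly smaller than its input, and the result covers
# only the frames the generator produced. The input file is assumed to be at
# least roughly half a second long at the default settings so the encoder
# does not collapse to zero frames.
def enhance_file(checkpoint_path: str, noisy_path: str, out_path: str) -> None:
    device = cfg.device

    G = UNetGenerator(in_ch=1, out_ch=1).to(device)
    ckpt = torch.load(checkpoint_path, map_location=device)
    G.load_state_dict(ckpt["G_state"])
    G.eval()

    # Load and preprocess exactly like the training dataset (mono, 16 kHz).
    wav, sr = torchaudio.load(noisy_path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != cfg.sample_rate:
        wav = torchaudio.transforms.Resample(sr, cfg.sample_rate)(wav)
    wav = wav.to(device)

    window = torch.hann_window(cfg.win_length, device=device)
    spec = torch.stft(
        wav,
        n_fft=cfg.n_fft,
        hop_length=cfg.hop_length,
        win_length=cfg.win_length,
        window=window,
        return_complex=True,
    )  # (1, F, T_spec), complex
    log_mag = torch.log1p(spec.abs()).unsqueeze(1)  # (1, 1, F, T_spec)
    phase = torch.angle(spec)

    with torch.no_grad():
        enh_log_mag = G(log_mag)

    enh_mag = torch.expm1(enh_log_mag.squeeze(1))  # undo log1p
    n_bins = cfg.n_fft // 2 + 1
    f_out, t_out = enh_mag.shape[-2], enh_mag.shape[-1]
    if f_out < n_bins:
        # Pad the missing top-frequency bins with zeros so istft accepts the input.
        enh_mag = torch.nn.functional.pad(enh_mag, (0, 0, 0, n_bins - f_out))

    # Combine enhanced magnitude with the noisy phase (cropped to t_out frames).
    enh_spec = torch.polar(enh_mag, phase[:, :, :t_out])
    enhanced = torch.istft(
        enh_spec,
        n_fft=cfg.n_fft,
        hop_length=cfg.hop_length,
        win_length=cfg.win_length,
        window=window,
    )  # (1, T')
    torchaudio.save(out_path, enhanced.cpu(), cfg.sample_rate)


# Example usage (hypothetical paths):
#     enhance_file(os.path.join(cfg.save_dir, "seagan_final.pt"),
#                  "noisy_example.wav", "enhanced_example.wav")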


if __name__ == "__main__":
    train()