#!/usr/bin/env python3
"""
SEGAN-style Speech Enhancement (Noise Removal) Training Script
- Generator: U-Net on log-magnitude spectrograms
- Discriminator: PatchGAN-style conditional (noisy + clean/enhanced)
- Loss: L1 (reconstruction) + adversarial (LSGAN)
Requirements:
    pip install torch torchaudio
"""
import os
import glob
import random
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
# ==========================
# CONFIG
# ==========================
class Config:
# Paths (CHANGE THESE TO YOUR FOLDERS)
noisy_dir = r"E:\Minor-Project-For-Amity-Patna\Models\Audio Data\Noisy Data" # noisy wavs
clean_dir = r"E:\Minor-Project-For-Amity-Patna\Models\Audio Data\Noiseless Data" # clean wavs
save_dir = r"E:\Minor-Project-For-Amity-Patna\Model SEGAN\checkpoints_seagan"
# Audio
sample_rate = 16000
segment_seconds = 1.0 # train on 1-second chunks
    mono = True  # clips are always downmixed to mono in SEGANDataset._load_audio
# STFT / Spectrogram
n_fft = 512
hop_length = 128
win_length = 512
# Training
batch_size = 8
num_workers = 2
num_epochs = 50
lr_g = 2e-4
lr_d = 2e-4
beta1 = 0.5
beta2 = 0.999
lambda_l1 = 100.0 # weight for L1 loss vs GAN loss (like pix2pix)
device = "cuda" if torch.cuda.is_available() else "cpu"
cfg = Config()
# ==========================
# DATASET
# ==========================
def list_wav_pairs(noisy_dir: str, clean_dir: str) -> List[Tuple[str, str]]:
    """Pair files by identical basename: noisy_dir/x.wav <-> clean_dir/x.wav."""
    noisy_files = sorted(glob.glob(os.path.join(noisy_dir, "*.wav")))
pairs = []
for nf in noisy_files:
name = os.path.basename(nf)
cf = os.path.join(clean_dir, name)
if os.path.exists(cf):
pairs.append((nf, cf))
return pairs
class SEGANDataset(Dataset):
def __init__(
self,
noisy_dir: str,
clean_dir: str,
sample_rate: int = 16000,
segment_seconds: float = 1.0,
):
self.sample_rate = sample_rate
self.segment_samples = int(segment_seconds * sample_rate)
self.pairs = list_wav_pairs(noisy_dir, clean_dir)
if len(self.pairs) == 0:
raise RuntimeError("No paired .wav files found! Check your folders & names.")
self.resampler_cache = {}
def __len__(self):
return len(self.pairs)
def _get_resampler(self, orig_sr: int):
if orig_sr == self.sample_rate:
return None
if orig_sr not in self.resampler_cache:
self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(
orig_freq=orig_sr, new_freq=self.sample_rate
)
return self.resampler_cache[orig_sr]
def _load_audio(self, path: str) -> torch.Tensor:
wav, sr = torchaudio.load(path) # shape: (channels, samples)
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True) # mono
resampler = self._get_resampler(sr)
if resampler is not None:
wav = resampler(wav)
return wav # (1, samples)
def _aligned_random_crop(self, noisy: torch.Tensor, clean: torch.Tensor):
"""
Crop noisy and clean with the same start index for alignment.
noisy, clean: (1, T)
"""
T = min(noisy.shape[1], clean.shape[1])
noisy = noisy[:, :T]
clean = clean[:, :T]
if T <= self.segment_samples:
pad = self.segment_samples - T
noisy = torch.nn.functional.pad(noisy, (0, pad))
clean = torch.nn.functional.pad(clean, (0, pad))
return noisy, clean
else:
start = random.randint(0, T - self.segment_samples)
end = start + self.segment_samples
return noisy[:, start:end], clean[:, start:end]
def __getitem__(self, idx: int):
noisy_path, clean_path = self.pairs[idx]
noisy = self._load_audio(noisy_path)
clean = self._load_audio(clean_path)
noisy, clean = self._aligned_random_crop(noisy, clean)
return noisy, clean
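# A quick manual sanity check (illustrative, assumes cfg points at valid
# folders): every item should come back as an aligned pair of
# (1, segment_samples) waveforms.
# >>> ds = SEGANDataset(cfg.noisy_dir, cfg.clean_dir, cfg.sample_rate, cfg.segment_seconds)
# >>> noisy, clean = ds[0]
# >>> noisy.shape == clean.shape == (1, ds.segment_samples)
# True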
# ==========================
# SPECTROGRAM HELPERS
# ==========================
class STFTMagTransform(nn.Module):
"""
Convert waveform -> log-magnitude spectrogram
"""
def __init__(self, n_fft, hop_length, win_length):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
# register window so it moves with .to(device)
self.register_buffer("window", torch.hann_window(win_length))
def forward(self, wav: torch.Tensor) -> torch.Tensor:
"""
wav: (B, 1, T)
return: (B, 1, F, T_spec)
"""
B, C, T = wav.shape
wav = wav.view(B * C, T)
spec = torch.stft(
wav,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
return_complex=True,
)
mag = torch.abs(spec) # (B*C, F, T_spec)
log_mag = torch.log1p(mag) # log(1 + mag)
log_mag = log_mag.view(B, C, log_mag.shape[1], log_mag.shape[2])
return log_mag
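# Training never needs to go back to audio, but inference will. A minimal
# inverse sketch (an addition, not part of the original pipeline), using the
# common trick of pairing the enhanced log1p-magnitude with the *noisy* phase:
def log_mag_to_wav(
    log_mag: torch.Tensor,
    noisy_wav: torch.Tensor,
    n_fft: int,
    hop_length: int,
    win_length: int,
) -> torch.Tensor:
    """
    log_mag:   (F', T') enhanced log1p-magnitude; may be slightly smaller than
               the noisy spectrogram (see the shape note under UNetGenerator)
    noisy_wav: (T,) noisy waveform supplying the phase
    returns:   (T'',) enhanced waveform
    """
    window = torch.hann_window(win_length, device=noisy_wav.device)
    noisy_spec = torch.stft(
        noisy_wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )  # (F, T_spec), complex
    # crop to the common number of frames and undo log1p
    T = min(log_mag.shape[-1], noisy_spec.shape[-1])
    mag = torch.expm1(log_mag[..., :T])
    # istft needs all n_fft // 2 + 1 bins; zero-pad any missing top bins
    if mag.shape[-2] < noisy_spec.shape[-2]:
        missing = noisy_spec.shape[-2] - mag.shape[-2]
        mag = torch.nn.functional.pad(mag, (0, 0, 0, missing))
    phase = torch.angle(noisy_spec[..., :T])
    return torch.istft(
        torch.polar(mag, phase),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
    )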
# ==========================
# SIZE MATCH HELPER
# ==========================
def match_size(a: torch.Tensor, b: torch.Tensor):
"""
Crop a and b to have the same (H, W). Keeps the top-left region.
a, b: (..., H, W)
returns: (a_crop, b_crop)
"""
Ha, Wa = a.shape[-2], a.shape[-1]
Hb, Wb = b.shape[-2], b.shape[-1]
H = min(Ha, Hb)
W = min(Wa, Wb)
a_c = a[..., :H, :W]
b_c = b[..., :H, :W]
return a_c, b_c
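# Illustrative example: the generator output comes back smaller than its input
# (see the shape note under UNetGenerator), so losses are computed on the
# overlapping region only.
# >>> a, b = match_size(torch.zeros(1, 1, 257, 126), torch.zeros(1, 1, 256, 64))
# >>> a.shape, b.shape
# (torch.Size([1, 1, 256, 64]), torch.Size([1, 1, 256, 64]))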
# ==========================
# GENERATOR (U-NET)
# ==========================
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, down=True, use_bn=True):
super().__init__()
if down:
layers = [
nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
]
else:
layers = [
nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True),
]
if use_bn:
layers.insert(1, nn.BatchNorm2d(out_ch))
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
class UNetGenerator(nn.Module):
"""
U-Net operating on (B, 1, F, T) log-magnitude spectrograms
"""
def __init__(self, in_ch=1, out_ch=1, base_ch=64):
super().__init__()
# Encoder
self.down1 = ConvBlock(in_ch, base_ch, down=True, use_bn=False) # (64)
self.down2 = ConvBlock(base_ch, base_ch * 2)
self.down3 = ConvBlock(base_ch * 2, base_ch * 4)
self.down4 = ConvBlock(base_ch * 4, base_ch * 8)
self.down5 = ConvBlock(base_ch * 8, base_ch * 8)
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(base_ch * 8, base_ch * 8, kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True),
)
        # Decoder (input channels are doubled where encoder skips are concatenated)
self.up1 = ConvBlock(base_ch * 8, base_ch * 8, down=False)
self.up2 = ConvBlock(base_ch * 8 * 2, base_ch * 8, down=False)
self.up3 = ConvBlock(base_ch * 8 * 2, base_ch * 4, down=False)
self.up4 = ConvBlock(base_ch * 4 * 2, base_ch * 2, down=False)
self.up5 = ConvBlock(base_ch * 2 * 2, base_ch, down=False)
self.final = nn.ConvTranspose2d(
base_ch * 2, out_ch, kernel_size=4, stride=2, padding=1
)
# Output non-negative log-magnitude
self.out_act = nn.ReLU()
def _crop_to(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
"""
Center-crop src to have the same H, W as tgt.
src: (B, C, Hs, Ws)
tgt: (B, C, Ht, Wt) (only Ht, Wt are used)
"""
_, _, Hs, Ws = src.shape
_, _, Ht, Wt = tgt.shape
if Hs == Ht and Ws == Wt:
return src
start_h = max((Hs - Ht) // 2, 0)
start_w = max((Ws - Wt) // 2, 0)
end_h = start_h + Ht
end_w = start_w + Wt
return src[:, :, start_h:end_h, start_w:end_w]
def forward(self, x):
# encoder
d1 = self.down1(x) # B,64
d2 = self.down2(d1) # B,128
d3 = self.down3(d2) # B,256
d4 = self.down4(d3) # B,512
d5 = self.down5(d4) # B,512
bott = self.bottleneck(d5)
# decoder with crops + skips
u1 = self.up1(bott)
d5_c = self._crop_to(d5, u1)
u1 = torch.cat([u1, d5_c], dim=1)
u2 = self.up2(u1)
d4_c = self._crop_to(d4, u2)
u2 = torch.cat([u2, d4_c], dim=1)
u3 = self.up3(u2)
d3_c = self._crop_to(d3, u3)
u3 = torch.cat([u3, d3_c], dim=1)
u4 = self.up4(u3)
d2_c = self._crop_to(d2, u4)
u4 = torch.cat([u4, d2_c], dim=1)
u5 = self.up5(u4)
d1_c = self._crop_to(d1, u5)
u5 = torch.cat([u5, d1_c], dim=1)
out = self.final(u5)
out = self.out_act(out) # non-negative log-magnitude
return out
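# Shape note (illustrative, with the defaults above): a 1 s / 16 kHz segment
# gives a (B, 1, 257, 126) spectrogram. The odd encoder sizes cannot be
# exactly inverted by the stride-2 transposed convs, so the output comes back
# smaller, and match_size() later aligns it with its targets:
# >>> G = UNetGenerator()
# >>> G(torch.randn(1, 1, 257, 126)).shape
# torch.Size([1, 1, 256, 64])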
# ==========================
# DISCRIMINATOR (PatchGAN)
# ==========================
class PatchDiscriminator(nn.Module):
"""
Conditional discriminator: input = concat(noisy_spec, clean_or_fake_spec)
"""
def __init__(self, in_ch=2, base_ch=64):
super().__init__()
# no batchnorm in first layer
self.model = nn.Sequential(
nn.Conv2d(in_ch, base_ch, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch, base_ch * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(base_ch * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch * 2, base_ch * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(base_ch * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch * 4, base_ch * 8, kernel_size=4, stride=1, padding=1),
nn.BatchNorm2d(base_ch * 8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch * 8, 1, kernel_size=4, stride=1, padding=1),
# no activation -> LSGAN
)
def forward(self, x):
return self.model(x) # (B, 1, H', W')
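# Each output logit scores one spectrogram patch; with three stride-2 and two
# stride-1 k=4 convs, the stack above has the classic pix2pix 70x70 PatchGAN
# receptive field. Illustrative shape check:
# >>> D = PatchDiscriminator(in_ch=2)
# >>> D(torch.randn(1, 2, 256, 64)).shape
# torch.Size([1, 1, 30, 6])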
# ==========================
# TRAINING
# ==========================
def save_checkpoint(epoch, G, D, opt_g, opt_d, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(
{
"epoch": epoch,
"G_state": G.state_dict(),
"D_state": D.state_dict(),
"opt_g_state": opt_g.state_dict(),
"opt_d_state": opt_d.state_dict(),
},
path,
)
print(f"Saved checkpoint: {path}")
def train():
device = cfg.device
print(f"Using device: {device}")
    dataset = SEGANDataset(
cfg.noisy_dir, cfg.clean_dir, cfg.sample_rate, cfg.segment_seconds
)
loader = DataLoader(
dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
drop_last=True,
)
stft_transform = STFTMagTransform(
cfg.n_fft, cfg.hop_length, cfg.win_length
).to(device)
G = UNetGenerator(in_ch=1, out_ch=1).to(device)
D = PatchDiscriminator(in_ch=2).to(device)
# LSGAN loss
criterion_gan = nn.MSELoss()
criterion_l1 = nn.L1Loss()
opt_g = optim.Adam(G.parameters(), lr=cfg.lr_g, betas=(cfg.beta1, cfg.beta2))
opt_d = optim.Adam(D.parameters(), lr=cfg.lr_d, betas=(cfg.beta1, cfg.beta2))
for epoch in range(1, cfg.num_epochs + 1):
G.train()
D.train()
running_g_loss = 0.0
running_d_loss = 0.0
for i, (noisy_wav, clean_wav) in enumerate(loader):
noisy_wav = noisy_wav.to(device) # (B,1,T)
clean_wav = clean_wav.to(device) # (B,1,T)
# -------------------------
# Waveform -> Spectrogram
# -------------------------
noisy_spec = stft_transform(noisy_wav) # (B,1,F,T_spec)
clean_spec = stft_transform(clean_wav) # (B,1,F,T_spec)
# Ensure same size for real pair
noisy_spec, clean_spec = match_size(noisy_spec, clean_spec)
# =========================
# Train Discriminator
# =========================
opt_d.zero_grad()
# Real pair: (noisy, clean)
real_input = torch.cat([noisy_spec, clean_spec], dim=1)
pred_real = D(real_input)
target_real = torch.ones_like(pred_real)
loss_d_real = criterion_gan(pred_real, target_real)
# Fake pair: (noisy, enhanced)
with torch.no_grad():
fake_spec = G(noisy_spec)
# match noisy and fake sizes
noisy_for_fake_d, fake_spec_d = match_size(noisy_spec, fake_spec)
fake_input = torch.cat([noisy_for_fake_d, fake_spec_d], dim=1)
pred_fake = D(fake_input)
target_fake = torch.zeros_like(pred_fake)
loss_d_fake = criterion_gan(pred_fake, target_fake)
loss_d = 0.5 * (loss_d_real + loss_d_fake)
loss_d.backward()
opt_d.step()
# =========================
# Train Generator
# =========================
opt_g.zero_grad()
fake_spec = G(noisy_spec)
# GAN loss (want D(noisy, fake) = 1)
noisy_for_fake_g, fake_spec_g = match_size(noisy_spec, fake_spec)
fake_input_g = torch.cat([noisy_for_fake_g, fake_spec_g], dim=1)
pred_fake_for_g = D(fake_input_g)
target_real_for_g = torch.ones_like(pred_fake_for_g)
loss_g_gan = criterion_gan(pred_fake_for_g, target_real_for_g)
# L1 reconstruction loss (match fake & clean sizes)
fake_l1, clean_l1 = match_size(fake_spec, clean_spec)
loss_g_l1 = criterion_l1(fake_l1, clean_l1) * cfg.lambda_l1
loss_g = loss_g_gan + loss_g_l1
loss_g.backward()
opt_g.step()
running_d_loss += loss_d.item()
running_g_loss += loss_g.item()
if (i + 1) % 20 == 0:
print(
f"Epoch [{epoch}/{cfg.num_epochs}] "
f"Step [{i+1}/{len(loader)}] "
f"D Loss: {loss_d.item():.4f} "
f"G Loss: {loss_g.item():.4f} "
f"(GAN: {loss_g_gan.item():.4f}, L1: {loss_g_l1.item():.4f})"
)
avg_d = running_d_loss / len(loader)
avg_g = running_g_loss / len(loader)
print(
f"==> Epoch {epoch} finished | "
f"Avg D Loss: {avg_d:.4f} | Avg G Loss: {avg_g:.4f}"
)
# save checkpoint every few epochs
if epoch % 5 == 0:
            ckpt_path = os.path.join(cfg.save_dir, f"segan_epoch_{epoch}.pt")
save_checkpoint(epoch, G, D, opt_g, opt_d, ckpt_path)
# final save
    ckpt_path = os.path.join(cfg.save_dir, "segan_final.pt")
save_checkpoint(cfg.num_epochs, G, D, opt_g, opt_d, ckpt_path)
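# End-to-end enhancement of a single file is not part of the training script;
# a minimal inference sketch (checkpoint and file paths are hypothetical),
# reusing the log_mag_to_wav helper sketched above. Assumes clips of roughly
# a second or longer, so the encoder does not collapse to zero-sized maps.
@torch.no_grad()
def enhance_file(in_path: str, out_path: str, ckpt_path: str):
    device = cfg.device
    G = UNetGenerator(in_ch=1, out_ch=1).to(device)
    G.load_state_dict(torch.load(ckpt_path, map_location=device)["G_state"])
    G.eval()
    wav, sr = torchaudio.load(in_path)  # (channels, T)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # downmix to mono, as in training
    if sr != cfg.sample_rate:
        wav = torchaudio.transforms.Resample(sr, cfg.sample_rate)(wav)
    wav = wav.to(device)
    stft = STFTMagTransform(cfg.n_fft, cfg.hop_length, cfg.win_length).to(device)
    log_mag = stft(wav.unsqueeze(0))  # (1, 1, F, T_spec)
    enhanced = G(log_mag)[0, 0]       # (F', T'), slightly smaller than input
    out = log_mag_to_wav(enhanced, wav[0], cfg.n_fft, cfg.hop_length, cfg.win_length)
    torchaudio.save(out_path, out.unsqueeze(0).cpu(), cfg.sample_rate)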
if __name__ == "__main__":
train()