# SEGAN / RemoveNoiseAndConvertIntoPNG.py
# Uploaded by 91prince
# Commit: "Add scripts for lossless audio-to-PNG packing and audio restoration."
# Commit hash: cd6dd93
#!/usr/bin/env python3
"""
RemoveNoiseAndConvertIntoPNG.py
- Runs your chunked SEGAN denoiser (mirror-pad last chunk, hann overlap-add, final spectral gating)
- Saves denoised audio (WAV) and packs denoised PCM into a lossless 16-bit PNG (mono).
- Optionally reconstructs WAV from PNG to demonstrate bit-perfect roundtrip.
Edit INPUT_AUDIO / OUTPUT_AUDIO / CHECKPOINT below.
Requires: torch, torchaudio, numpy, Pillow (PIL)
"""
import os
import math
import wave
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np
from PIL import Image
# Import SEGAN components - keep your module with these classes available
from SEGAN import Config, STFTMagTransform, UNetGenerator
# ---------------- USER CONFIG ----------------
# Absolute paths for the noisy input WAV, the denoised output WAV, and the
# trained SEGAN generator checkpoint. Edit these before running the script.
INPUT_AUDIO = r"E:\Test Audio Data\AudioSong\romantic-song-tera-roothna-by-ashir-hindi-top-trending-viral-song-231771.wav"
OUTPUT_AUDIO = r"E:\Test Audio Data\Denoised Results\romantic-song-tera-roothna-by-ashir-hindi-top-trending-viral-song-231771.wav"
CHECKPOINT = r"E:\Minor-Project-For-Amity-Patna\Model SEGAN\checkpoints_seagan\seagan_final.pt"
CHUNK_SECONDS = 50.0  # chunk length (seconds) passed to the chunked denoiser
CHUNK_OVERLAP = 0.5  # fraction overlap between consecutive chunks (0.5 = 50%)
USE_SPECTRAL_GATE = True  # apply final spectral gating to suppress residual noise
NOISE_FRAC = 0.10  # fraction of lowest-energy frames used to estimate the noise floor
SUBTRACT_STRENGTH = 1.0  # how strongly to subtract the noise floor (1.0 = full)
# PNG packing options
PNG_WIDTH = 2048  # pixels per row for the packed PNG (adjust for memory/shape)
# ---------------------------------------------
class InferConfig(Config):
    # Inference-time overrides of the training Config: where to load the
    # generator checkpoint from and which device to run on.
    ckpt_path = CHECKPOINT
    device = "cuda" if torch.cuda.is_available() else "cpu"

icfg = InferConfig()  # shared inference configuration instance
# ---------------- utilities -------------------
def load_mono_resampled(path: str, target_sr: int):
    """Load *path* as a mono waveform resampled to *target_sr*.

    Returns a 1-D float tensor of shape (T,).
    """
    audio, native_sr = torchaudio.load(path)
    # Downmix multi-channel audio to mono by averaging channels.
    if audio.size(0) > 1:
        audio = audio.mean(dim=0, keepdim=True)
    # Resample only when the file's native rate differs from the target.
    if native_sr != target_sr:
        audio = torchaudio.transforms.Resample(native_sr, target_sr)(audio)
    return audio.squeeze(0)  # (T,)
def robust_save(path: str, wav_tensor: torch.Tensor, sr: int):
    """Save *wav_tensor* to *path* as a WAV file at sample rate *sr*.

    Accepts tensors of various ranks: a 1-D (T,) tensor is promoted to
    (1, T), and leading unit batch dims are squeezed away until the tensor
    is (channels, T) as torchaudio.save expects.
    """
    x = wav_tensor.detach().cpu()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    # Remove leading unit batch dims, e.g. (1, 1, T) -> (1, T).
    while x.dim() > 2 and x.size(0) == 1:
        x = x.squeeze(0)
    if x.dim() > 2:
        x = torch.squeeze(x)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    x = x.float()
    # BUG FIX: os.path.dirname() returns "" for a bare filename and
    # os.makedirs("") raises FileNotFoundError -- only create the directory
    # when the path actually contains one.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchaudio.save(path, x, sr)
    print(f"Saved WAV: {path} (shape={tuple(x.shape)})")
def pad_or_crop_freq(mag: torch.Tensor, target_F: int):
    """Force the frequency axis of *mag* (shape (1, F, T)) to *target_F* bins.

    Shorter inputs are zero-padded at the high-frequency end; longer inputs
    are cropped; matching inputs are returned unchanged.
    """
    current = mag.shape[1]
    if current == target_F:
        return mag
    if current > target_F:
        # Drop the excess high-frequency bins.
        return mag[:, :target_F, :]
    # Zero-pad the frequency dimension (second-to-last) up to target_F.
    missing = target_F - current
    return F.pad(mag, (0, 0, 0, missing))
def mirror_pad_last_chunk(chunk: torch.Tensor, target_len: int):
    """Extend *chunk* (shape (1, 1, L)) to exactly *target_len* samples.

    The tail of the chunk is mirrored (time-reversed) onto the end; if that
    still falls short the remainder is zero-filled. Chunks already at or
    beyond target_len are simply cropped.
    """
    length = chunk.shape[-1]
    if length >= target_len:
        return chunk[:, :, :target_len]
    shortfall = target_len - length
    # Reflect up to `shortfall` trailing samples back onto the end.
    mirrored = chunk[..., -min(length, shortfall):].flip(-1)
    extended = torch.cat([chunk, mirrored], dim=-1)
    if extended.shape[-1] < target_len:
        # Mirroring alone was not enough (shortfall > length): zero-pad the rest.
        extended = F.pad(extended, (0, target_len - extended.shape[-1]))
    return extended[:, :, :target_len]
# ---------------- spectral gating (final cleanup) ----------------
def spectral_subtract_and_reconstruct(waveform: torch.Tensor, stft_mod: STFTMagTransform, cfg: InferConfig,
                                      noise_frac=0.1, subtract_strength=1.0, device='cpu'):
    """
    Final spectral-gating cleanup.

    waveform: (T,) or (1, T) tensor (CPU).
    The noise floor is estimated as the per-bin median magnitude over the
    `noise_frac` lowest-energy frames, scaled by `subtract_strength` and
    soft-subtracted from the magnitude spectrogram (clamped at zero); the
    signal is then rebuilt with the original phase via torch.istft.
    Returns the denoised waveform as a 1-D CPU tensor of the input length.
    """
    wav = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
    wav = wav.to(device)
    n_fft = cfg.n_fft
    hop = cfg.hop_length
    win = stft_mod.window.to(device)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop, win_length=cfg.win_length,
                      window=win, return_complex=True)
    mag = spec.abs()      # (1, F, T)
    phase = spec.angle()  # (1, F, T)
    # Per-frame energy, used to pick the quietest frames as "noise only".
    frame_energy = mag.pow(2).sum(dim=1).squeeze(0)  # (T,)
    total_frames = frame_energy.shape[-1]
    if total_frames <= 0:
        return wav.squeeze(0).cpu()
    quiet_count = max(1, int(total_frames * noise_frac))
    quiet_idx = torch.argsort(frame_energy)[:quiet_count]
    # Median magnitude per frequency bin across the quiet frames: (1, F).
    noise_floor = mag[:, :, quiet_idx].median(dim=-1).values
    # Broadcast the (1, F, 1) floor across all frames during subtraction.
    gated = torch.clamp(mag - subtract_strength * noise_floor.unsqueeze(-1), min=0.0)
    # Recombine the gated magnitude with the original (noisy) phase.
    complex_gated = torch.complex(gated * torch.cos(phase), gated * torch.sin(phase))
    recon = torch.istft(complex_gated, n_fft=n_fft, hop_length=hop,
                        win_length=cfg.win_length, window=win, length=wav.shape[-1])
    return recon.squeeze(0).cpu()
# ---------------- core chunked denoiser (improved) ----------------
def denoise_chunked_final(input_path: str, output_path: str, cfg: InferConfig,
                          chunk_seconds=3.0, overlap=0.5,
                          use_spectral_gate=True, noise_frac=0.1, subtract_strength=1.0):
    """
    Full chunked denoising pipeline.

    Loads the SEGAN generator from cfg.ckpt_path, splits the input audio into
    overlapping chunks, denoises each chunk in the magnitude-spectrogram
    domain (reusing the noisy phase), recombines chunks by Hann-windowed
    overlap-add, optionally applies a final spectral gate, then saves the
    result as WAV, packs it into a lossless 16-bit PNG, and reconstructs a
    verification WAV from that PNG.

    Returns (output_wav_path, png_path, reconstructed_wav_path).
    """
    device = cfg.device
    print("Device:", device)
    # load model + stft
    print("Loading checkpoint:", cfg.ckpt_path)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    ckpt = torch.load(cfg.ckpt_path, map_location=device)
    G = UNetGenerator(in_ch=1, out_ch=1).to(device)
    G.load_state_dict(ckpt["G_state"])
    G.eval()
    stft = STFTMagTransform(cfg.n_fft, cfg.hop_length, cfg.win_length).to(device)
    window = stft.window.to(device)
    # load audio
    wav = load_mono_resampled(input_path, cfg.sample_rate)  # (T,)
    T = wav.shape[0]
    sr = cfg.sample_rate
    print(f"Input: {T} samples ({T/sr:.2f} s) SR={sr}")
    chunk_samples = max(1, int(chunk_seconds * sr))
    # Hop between chunk starts; overlap=0.5 gives 50% chunk overlap.
    hop = max(1, int(chunk_samples * (1.0 - overlap)))
    print(f"Chunk {chunk_samples} samples, hop {hop} samples")
    # Accumulation buffers are oversized by one chunk so the last
    # (mirror-padded) chunk can be written without bounds checks.
    out_len = T + chunk_samples
    out_buffer = torch.zeros(out_len, dtype=torch.float32)
    weight_buffer = torch.zeros(out_len, dtype=torch.float32)
    # Hann synthesis window for overlap-add recombination.
    synth_win = torch.hann_window(chunk_samples, periodic=True, dtype=torch.float32)
    idx = 0
    while idx < T:
        start = idx
        end = min(idx + chunk_samples, T)
        chunk = wav[start:end].unsqueeze(0).unsqueeze(0).to(device)  # (1,1,L)
        orig_len = chunk.shape[-1]
        if orig_len < chunk_samples:
            # Last chunk: mirror-pad up to the full chunk length.
            chunk = mirror_pad_last_chunk(chunk, chunk_samples).to(device)
        with torch.no_grad():
            spec = stft(chunk)  # (1,1,F_spec,Frames)
            fake = G(spec)  # (1,1,F_fake,Frames)
            # Generator output is treated as log1p-magnitude; expm1 inverts it.
            mag = torch.expm1(fake.clamp_min(0.0)).squeeze(1)  # (1,F_fake,Frames)
            # Recompute the noisy STFT to reuse its phase for reconstruction.
            chunk_1d = chunk.view(1, -1)
            complex_noisy = torch.stft(chunk_1d, n_fft=cfg.n_fft, hop_length=cfg.hop_length,
                                       win_length=cfg.win_length, window=window, return_complex=True)
            phase = torch.angle(complex_noisy)  # (1,F_phase,Frames_phase)
            # Align frame counts and frequency bins between generator output
            # and noisy phase before recombining magnitude with phase.
            n_frames_mag = mag.shape[-1]
            n_frames_phase = phase.shape[-1]
            min_frames = min(n_frames_mag, n_frames_phase)
            mag = mag[..., :min_frames]
            phase = phase[..., :min_frames]
            expected_F = cfg.n_fft // 2 + 1
            mag = pad_or_crop_freq(mag, expected_F)
            real = mag * torch.cos(phase)
            imag = mag * torch.sin(phase)
            complex_spec = torch.complex(real, imag).squeeze(0)  # (F, frames)
            wav_rec = torch.istft(complex_spec.unsqueeze(0).to(device),
                                  n_fft=cfg.n_fft, hop_length=cfg.hop_length,
                                  win_length=cfg.win_length, window=window,
                                  length=chunk_samples).squeeze(0).cpu()
        # Defensive length fix-up before windowed accumulation.
        if wav_rec.shape[-1] < chunk_samples:
            wav_rec = F.pad(wav_rec, (0, chunk_samples - wav_rec.shape[-1]))
        elif wav_rec.shape[-1] > chunk_samples:
            wav_rec = wav_rec[:chunk_samples]
        # Hann-weighted overlap-add accumulation.
        win = synth_win.clone().cpu()
        wav_rec_win = wav_rec * win
        write_start = start
        write_end = start + chunk_samples
        out_buffer[write_start:write_end] += wav_rec_win
        weight_buffer[write_start:write_end] += win
        idx += hop
    # Normalize by the accumulated window weight (skip near-zero weights).
    nonzero = weight_buffer > 1e-8
    out_buffer[nonzero] = out_buffer[nonzero] / weight_buffer[nonzero]
    denoised = out_buffer[:T].contiguous()
    if use_spectral_gate:
        print("Applying final spectral gating...")
        denoised = spectral_subtract_and_reconstruct(denoised.unsqueeze(0), stft, cfg,
                                                     noise_frac=noise_frac, subtract_strength=subtract_strength,
                                                     device=cfg.device)
    # Headroom clamp before 16-bit quantization downstream.
    denoised = torch.clamp(denoised, -0.999, 0.999)
    robust_save(output_path, denoised, sr)
    # After saving WAV, pack denoised audio into lossless PNG
    png_path = os.path.splitext(output_path)[0] + "_packed.png"
    save_audio_as_png_lossless(denoised, png_path, width=PNG_WIDTH)
    print("Packed denoised audio into PNG:", png_path)
    # Optionally: reconstruct WAV from PNG to verify roundtrip correctness
    recon_wav = os.path.splitext(output_path)[0] + "_reconstructed_from_png.wav"
    restored = load_audio_from_png_lossless(png_path, original_length=denoised.shape[-1])
    write_wav_from_tensor(restored, recon_wav, sr)
    print("Reconstructed WAV from PNG:", recon_wav)
    return output_path, png_path, recon_wav
# === Lossless audio <-> PNG packing (bit-perfect) ===
def audio_tensor_to_int16_array(wav_tensor: torch.Tensor):
    """Convert a float waveform (-1..1), 1-D or (1, T), to a 1-D numpy int16 array.

    Values are clipped to [-1, 1] and scaled by 32767.

    BUG FIX: the previous version quantized with a bare astype(np.int16),
    which truncates toward zero; np.rint rounds to the nearest integer and
    avoids a systematic quantization bias of up to one LSB.
    """
    if isinstance(wav_tensor, torch.Tensor):
        x = wav_tensor.detach().cpu().numpy()
    else:
        x = np.asarray(wav_tensor)
    # Squeeze a leading unit channel dim: (1, T) -> (T,).
    if x.ndim == 2 and x.shape[0] == 1:
        x = x[0]
    x = np.clip(x, -1.0, 1.0)
    return np.rint(x * 32767.0).astype(np.int16)
def int16_array_to_audio_tensor(int16_arr: np.ndarray):
    """Convert int16 PCM samples back to a mono float32 torch tensor in -1..1."""
    pcm = np.asarray(int16_arr, dtype=np.int16)
    normalized = pcm.astype(np.float32) / 32767.0
    return torch.from_numpy(normalized)
def save_audio_as_png_lossless(wav_tensor: torch.Tensor, png_path: str, width: int = 2048):
    """
    Pack int16 PCM samples into a single-channel 16-bit PNG (bit-perfect).

    width = number of pixels per row; the height is computed from the sample
    count and the tail row is zero-padded. Returns png_path.
    """
    samples = audio_tensor_to_int16_array(wav_tensor)
    N = samples.shape[0]
    # ROBUSTNESS: empty input previously produced a 0-row image that PIL
    # cannot encode; keep at least one (all-zero) row.
    height = max(1, math.ceil(N / width))
    total = width * height
    pad = total - N
    padded = np.pad(samples, (0, pad), mode='constant', constant_values=0).astype(np.int16)
    arr = padded.reshape((height, width))
    # Reinterpret signed int16 bits as uint16 so PIL can save without changing bits.
    uint16_view = arr.view(np.uint16)
    im = Image.fromarray(uint16_view, mode='I;16')
    # BUG FIX: os.path.dirname() is "" for a bare filename and os.makedirs("")
    # raises -- only create the directory when the path contains one.
    out_dir = os.path.dirname(png_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    im.save(png_path, format='PNG')
    print(f"Saved lossless audio PNG: {png_path} (samples={N}, width={width}, height={height})")
    return png_path
def load_audio_from_png_lossless(png_path: str, original_length: int = None):
    """
    Read a 16-bit PNG written by save_audio_as_png_lossless and return a 1-D
    torch.FloatTensor mono waveform in -1..1.
    If original_length is provided, the trailing padding samples are cropped.
    """
    # BUG FIX: use a context manager so the image file handle is closed
    # deterministically instead of leaking until garbage collection.
    with Image.open(png_path) as im:
        arr_uint16 = np.array(im, dtype=np.uint16)
    # Reinterpret the uint16 pixel bits as the original signed int16 PCM.
    int16_arr = arr_uint16.view(np.int16).reshape(-1)
    if original_length is not None:
        int16_arr = int16_arr[:original_length]
    float_tensor = int16_array_to_audio_tensor(int16_arr)
    return float_tensor  # 1D torch tensor
def write_wav_from_tensor(tensor: torch.Tensor, out_wav_path: str, sr: int):
    """Write a 1-D float tensor (-1..1) to *out_wav_path* as 16-bit mono PCM WAV.

    BUG FIXES: quantizes with np.rint (round-to-nearest) instead of the
    truncating astype cast, and only calls os.makedirs when the path has a
    directory component (os.makedirs("") raises for bare filenames).
    """
    x = tensor.detach().cpu().numpy()
    int16 = np.rint(np.clip(x, -1.0, 1.0) * 32767.0).astype(np.int16)
    out_dir = os.path.dirname(out_wav_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with wave.open(out_wav_path, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 16-bit PCM
        wf.setframerate(sr)
        wf.writeframes(int16.tobytes())
    print(f"WAV written (lossless restore): {out_wav_path} (samples={int16.size}, sr={sr})")
    return out_wav_path
# ---------------- run ----------------
if __name__ == "__main__":
    # Entry point: run the full pipeline (chunked SEGAN denoise -> WAV ->
    # lossless PNG pack -> PNG->WAV roundtrip check) with the USER CONFIG values.
    print("Running final chunked denoiser -> PNG (lossless) ...")
    out_wav, out_png, out_recon = denoise_chunked_final(INPUT_AUDIO, OUTPUT_AUDIO, icfg,
                                                        chunk_seconds=CHUNK_SECONDS,
                                                        overlap=CHUNK_OVERLAP,
                                                        use_spectral_gate=USE_SPECTRAL_GATE,
                                                        noise_frac=NOISE_FRAC,
                                                        subtract_strength=SUBTRACT_STRENGTH)
    print("Done.")
    print("Denoised WAV:", out_wav)
    print("Packed PNG:", out_png)
    print("Reconstructed WAV from PNG (verification):", out_recon)