File size: 13,336 Bytes
cd6dd93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
#!/usr/bin/env python3
"""
ChunkedDenoise-SEGAN-to-PNG.py

- Runs your chunked SEGAN denoiser (mirror-pad last chunk, hann overlap-add, final spectral gating)
- Saves denoised audio (WAV) and packs denoised PCM into a lossless 16-bit PNG (mono).
- Optionally reconstructs WAV from PNG to demonstrate bit-perfect roundtrip.

Edit INPUT_AUDIO / OUTPUT_AUDIO / CHECKPOINT below.
Requires: torch, torchaudio, numpy, Pillow (PIL)
"""

import os
import math
import wave
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np
from PIL import Image

# Import SEGAN components - keep your module with these classes available
from SEGAN import Config, STFTMagTransform, UNetGenerator

# ---------------- USER CONFIG ----------------
# Absolute paths for this run; edit before executing.
INPUT_AUDIO  = r"E:\Test Audio Data\AudioSong\romantic-song-tera-roothna-by-ashir-hindi-top-trending-viral-song-231771.wav"
OUTPUT_AUDIO = r"E:\Test Audio Data\Denoised Results\romantic-song-tera-roothna-by-ashir-hindi-top-trending-viral-song-231771.wav"
CHECKPOINT   = r"E:\Minor-Project-For-Amity-Patna\Model SEGAN\checkpoints_seagan\seagan_final.pt"

CHUNK_SECONDS      = 50.0         # length of each processing chunk, in seconds
CHUNK_OVERLAP      = 0.5          # fractional overlap between chunks (0.5 = 50%)
USE_SPECTRAL_GATE  = True         # apply final spectral gating pass to suppress residual noise
NOISE_FRAC         = 0.10        # fraction of lowest-energy STFT frames used to estimate the noise floor
SUBTRACT_STRENGTH  = 1.0         # noise-floor subtraction strength (1.0 = full subtraction)

# PNG packing options
PNG_WIDTH = 2048  # pixels per row of the packed 16-bit PNG (trades image shape vs. height)
# ---------------------------------------------

class InferConfig(Config):
    """Inference-time config: inherits STFT/sample-rate settings from SEGAN.Config,
    overriding only the checkpoint path and compute device."""
    ckpt_path = CHECKPOINT
    # Prefer GPU when available; all tensors/model weights follow this device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level singleton used by the __main__ entry point below.
icfg = InferConfig()

# ---------------- utilities -------------------

def load_mono_resampled(path: str, target_sr: int):
    """Load an audio file, fold it to mono, and resample to *target_sr*.

    Returns a 1-D float tensor of samples, shape (T,).
    """
    audio, native_sr = torchaudio.load(path)
    # Average the channels rather than dropping any, so no content is lost.
    if audio.size(0) > 1:
        audio = audio.mean(dim=0, keepdim=True)
    if native_sr != target_sr:
        resampler = torchaudio.transforms.Resample(native_sr, target_sr)
        audio = resampler(audio)
    return audio.squeeze(0)  # (T,)

def robust_save(path: str, wav_tensor: torch.Tensor, sr: int):
    """Save *wav_tensor* to *path* as a WAV, coercing it to (channels, T).

    Accepts a 1-D (T,), 2-D (C, T), or higher-rank tensor with leading unit
    batch dims; extra dims are squeezed away before handing off to
    torchaudio.save. Creates the destination directory when needed.
    """
    x = wav_tensor.detach().cpu()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    # Strip leading unit batch dims, e.g. (1, 1, T) -> (1, T).
    while x.dim() > 2 and x.size(0) == 1:
        x = x.squeeze(0)
    if x.dim() > 2:
        x = torch.squeeze(x)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    x = x.float()
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (fixes saving to a bare filename).
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchaudio.save(path, x, sr)
    print(f"Saved WAV: {path} (shape={tuple(x.shape)})")

def pad_or_crop_freq(mag: torch.Tensor, target_F: int):
    """Force the frequency axis of *mag* (shape (1, F, T)) to exactly *target_F* bins.

    Shorter inputs are zero-padded at the high-frequency end; longer inputs
    are cropped. An exact match is returned unchanged.
    """
    current = mag.shape[1]
    if current < target_F:
        # F.pad pairs run innermost-dim first: (T_left, T_right, F_left, F_right).
        return F.pad(mag, (0, 0, 0, target_F - current))
    if current > target_F:
        return mag[:, :target_F, :]
    return mag

def mirror_pad_last_chunk(chunk: torch.Tensor, target_len: int):
    """Extend a (1, 1, L) chunk to *target_len* samples by reflecting its tail.

    A chunk that is already long enough is cropped. When one mirrored tail
    cannot cover the deficit, the remainder is zero-filled.
    """
    length = chunk.shape[-1]
    if length >= target_len:
        return chunk[:, :, :target_len]
    deficit = target_len - length
    tail = min(length, deficit)
    mirrored = torch.cat([chunk, chunk[..., -tail:].flip(-1)], dim=-1)
    shortfall = target_len - mirrored.shape[-1]
    if shortfall > 0:
        mirrored = F.pad(mirrored, (0, shortfall))
    return mirrored[:, :, :target_len]

# ---------------- spectral gating (final cleanup) ----------------

def spectral_subtract_and_reconstruct(waveform: torch.Tensor, stft_mod: STFTMagTransform, cfg: InferConfig,
                                      noise_frac=0.1, subtract_strength=1.0, device='cpu'):
    """
    Final-pass spectral gate.

    waveform: (T,) or (1, T) tensor (CPU).
    The noise floor is the per-bin median magnitude over the `noise_frac`
    lowest-energy STFT frames; it is scaled by `subtract_strength`,
    subtracted from the magnitude spectrogram (clamped at zero), and the
    signal is rebuilt with the original phase via torch.istft.
    Returns the denoised waveform on CPU.
    """
    wav = waveform if waveform.dim() != 1 else waveform.unsqueeze(0)  # (1, T)
    wav = wav.to(device)

    window = stft_mod.window.to(device)
    spec = torch.stft(wav, n_fft=cfg.n_fft, hop_length=cfg.hop_length,
                      win_length=cfg.win_length, window=window, return_complex=True)
    mag = torch.abs(spec)      # (1, F, T)
    phase = torch.angle(spec)  # (1, F, T)

    # Per-frame energy, used to pick the quietest frames as "noise only".
    frame_energy = mag.pow(2).sum(dim=1).squeeze(0)  # (T,)
    n_frames = frame_energy.shape[-1]
    if n_frames <= 0:
        return wav.squeeze(0).cpu()

    n_quiet = max(1, int(n_frames * noise_frac))
    quiet_idx = torch.argsort(frame_energy)[:n_quiet]
    floor = mag[:, :, quiet_idx].median(dim=-1).values           # (1, F)
    floor = floor.unsqueeze(-1).repeat(1, 1, mag.shape[-1])      # broadcast over frames

    gated = torch.clamp(mag - subtract_strength * floor, min=0.0)

    # Recombine the gated magnitude with the untouched phase.
    cleaned = torch.complex(gated * torch.cos(phase), gated * torch.sin(phase))
    recon = torch.istft(cleaned, n_fft=cfg.n_fft, hop_length=cfg.hop_length,
                        win_length=cfg.win_length, window=window, length=wav.shape[-1])
    return recon.squeeze(0).cpu()

# ---------------- core chunked denoiser (improved) ----------------

def denoise_chunked_final(input_path: str, output_path: str, cfg: InferConfig,
                          chunk_seconds=3.0, overlap=0.5,
                          use_spectral_gate=True, noise_frac=0.1, subtract_strength=1.0):
    """Denoise a whole audio file with the SEGAN generator, chunk by chunk.

    Pipeline: load mono audio at cfg.sample_rate -> split into overlapping
    chunks (the last chunk is mirror-padded to full length) -> per chunk,
    predict a clean magnitude spectrogram with the generator and recombine it
    with the noisy chunk's own phase via torch.istft -> hann-windowed
    overlap-add into one buffer -> optional spectral gating -> save the WAV,
    pack the samples into a lossless 16-bit PNG, and rebuild a WAV from that
    PNG as a round-trip check.

    Returns (output_path, png_path, reconstructed_wav_path).
    """
    device = cfg.device
    print("Device:", device)

    # load model + stft
    print("Loading checkpoint:", cfg.ckpt_path)
    ckpt = torch.load(cfg.ckpt_path, map_location=device)
    G = UNetGenerator(in_ch=1, out_ch=1).to(device)
    G.load_state_dict(ckpt["G_state"])
    G.eval()  # inference mode: freeze dropout/batch-norm behavior

    stft = STFTMagTransform(cfg.n_fft, cfg.hop_length, cfg.win_length).to(device)
    window = stft.window.to(device)

    # load audio
    wav = load_mono_resampled(input_path, cfg.sample_rate)  # (T,)
    T = wav.shape[0]
    sr = cfg.sample_rate
    print(f"Input: {T} samples ({T/sr:.2f} s) SR={sr}")

    # Chunk geometry: hop < chunk_samples gives the requested overlap.
    chunk_samples = max(1, int(chunk_seconds * sr))
    hop = max(1, int(chunk_samples * (1.0 - overlap)))
    print(f"Chunk {chunk_samples} samples, hop {hop} samples")

    # Overlap-add accumulators, oversized by one chunk so the final
    # (padded) chunk can be written without bounds checks.
    out_len = T + chunk_samples
    out_buffer = torch.zeros(out_len, dtype=torch.float32)
    weight_buffer = torch.zeros(out_len, dtype=torch.float32)

    # Synthesis window for cross-fading overlapping chunks.
    synth_win = torch.hann_window(chunk_samples, periodic=True, dtype=torch.float32)

    idx = 0
    while idx < T:
        start = idx
        end = min(idx + chunk_samples, T)
        chunk = wav[start:end].unsqueeze(0).unsqueeze(0).to(device)  # (1,1,L)
        orig_len = chunk.shape[-1]
        if orig_len < chunk_samples:
            # Last chunk: mirror-pad to full length so STFT frame counts match.
            chunk = mirror_pad_last_chunk(chunk, chunk_samples).to(device)

        with torch.no_grad():
            spec = stft(chunk)             # (1,1,F_spec,Frames)
            fake = G(spec)                 # (1,1,F_fake,Frames)
            # NOTE(review): expm1 suggests the generator was trained on
            # log1p(magnitude) targets — confirm against the training code.
            mag = torch.expm1(fake.clamp_min(0.0)).squeeze(1)  # (1,F_fake,Frames)

        # Recompute the noisy chunk's STFT only to borrow its phase.
        chunk_1d = chunk.view(1, -1)
        complex_noisy = torch.stft(chunk_1d, n_fft=cfg.n_fft, hop_length=cfg.hop_length,
                                   win_length=cfg.win_length, window=window, return_complex=True)
        phase = torch.angle(complex_noisy)  # (1,F_phase,Frames_phase)

        # Align frame counts in case the model changed the time dimension.
        n_frames_mag = mag.shape[-1]
        n_frames_phase = phase.shape[-1]
        min_frames = min(n_frames_mag, n_frames_phase)
        mag = mag[..., :min_frames]
        phase = phase[..., :min_frames]

        # Align frequency bins to what torch.istft expects (n_fft//2 + 1).
        expected_F = cfg.n_fft // 2 + 1
        mag = pad_or_crop_freq(mag, expected_F)

        # Combine predicted magnitude with noisy phase, back to time domain.
        real = mag * torch.cos(phase)
        imag = mag * torch.sin(phase)
        complex_spec = torch.complex(real, imag).squeeze(0)  # (F, frames)

        wav_rec = torch.istft(complex_spec.unsqueeze(0).to(device),
                              n_fft=cfg.n_fft, hop_length=cfg.hop_length,
                              win_length=cfg.win_length, window=window,
                              length=chunk_samples).squeeze(0).cpu()

        # Force the reconstructed chunk to exactly chunk_samples samples.
        if wav_rec.shape[-1] < chunk_samples:
            wav_rec = F.pad(wav_rec, (0, chunk_samples - wav_rec.shape[-1]))
        elif wav_rec.shape[-1] > chunk_samples:
            wav_rec = wav_rec[:chunk_samples]

        # Windowed overlap-add; the weight buffer records total window mass
        # per sample so the sum can be normalized afterwards.
        win = synth_win.clone().cpu()
        wav_rec_win = wav_rec * win

        write_start = start
        write_end = start + chunk_samples
        out_buffer[write_start:write_end] += wav_rec_win
        weight_buffer[write_start:write_end] += win

        idx += hop

    # Normalize by accumulated window weight (guarding near-zero entries).
    nonzero = weight_buffer > 1e-8
    out_buffer[nonzero] = out_buffer[nonzero] / weight_buffer[nonzero]
    denoised = out_buffer[:T].contiguous()

    if use_spectral_gate:
        print("Applying final spectral gating...")
        denoised = spectral_subtract_and_reconstruct(denoised.unsqueeze(0), stft, cfg,
                                                     noise_frac=noise_frac, subtract_strength=subtract_strength,
                                                     device=cfg.device)

    # Headroom clamp so int16 quantization downstream cannot clip/wrap.
    denoised = torch.clamp(denoised, -0.999, 0.999)

    robust_save(output_path, denoised, sr)

    # After saving WAV, pack denoised audio into lossless PNG
    png_path = os.path.splitext(output_path)[0] + "_packed.png"
    save_audio_as_png_lossless(denoised, png_path, width=PNG_WIDTH)
    print("Packed denoised audio into PNG:", png_path)

    # Optionally: reconstruct WAV from PNG to verify roundtrip correctness
    recon_wav = os.path.splitext(output_path)[0] + "_reconstructed_from_png.wav"
    restored = load_audio_from_png_lossless(png_path, original_length=denoised.shape[-1])
    write_wav_from_tensor(restored, recon_wav, sr)
    print("Reconstructed WAV from PNG:", recon_wav)

    return output_path, png_path, recon_wav

# === Lossless audio <-> PNG packing (bit-perfect) ===

def audio_tensor_to_int16_array(wav_tensor: torch.Tensor):
    """Convert float samples in [-1, 1] (1-D, or shape (1, T)) to a 1-D numpy int16 array."""
    if isinstance(wav_tensor, torch.Tensor):
        samples = wav_tensor.detach().cpu().numpy()
    else:
        samples = np.asarray(wav_tensor)
    if samples.ndim == 2 and samples.shape[0] == 1:
        samples = samples[0]
    # Clamp before scaling so out-of-range floats cannot wrap around int16.
    return (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)

def int16_array_to_audio_tensor(int16_arr: np.ndarray):
    """Map int16 PCM samples back to a mono float32 torch tensor in [-1, 1]."""
    pcm = np.asarray(int16_arr, dtype=np.int16)
    return torch.from_numpy(pcm.astype(np.float32) / 32767.0)

def save_audio_as_png_lossless(wav_tensor: torch.Tensor, png_path: str, width: int = 2048):
    """
    Pack int16 PCM samples into a single-channel 16-bit PNG.

    width = pixels per row; height is computed automatically and the tail of
    the last row is zero-padded. The int16 buffer is reinterpreted as uint16
    bit-for-bit, so the PNG round trip is lossless.
    Returns png_path.
    """
    samples = audio_tensor_to_int16_array(wav_tensor)
    N = samples.shape[0]
    height = math.ceil(N / width)
    total = width * height
    pad = total - N
    padded = np.pad(samples, (0, pad), mode='constant', constant_values=0).astype(np.int16)

    arr = padded.reshape((height, width))

    # reinterpret signed int16 bits as uint16 so PIL can save without changing bits
    uint16_view = arr.view(np.uint16)

    im = Image.fromarray(uint16_view, mode='I;16')
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (fixes saving to a bare filename).
    out_dir = os.path.dirname(png_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    im.save(png_path, format='PNG')
    print(f"Saved lossless audio PNG: {png_path} (samples={N}, width={width}, height={height})")
    return png_path

def load_audio_from_png_lossless(png_path: str, original_length: int = None):
    """
    Inverse of save_audio_as_png_lossless: read the 16-bit PNG back into a
    1-D float torch tensor in [-1, 1]. Pass original_length to drop the zero
    padding that filled out the final pixel row.
    """
    image = Image.open(png_path)
    raw = np.array(image, dtype=np.uint16)
    # Reinterpret the stored uint16 bits as the original signed samples.
    pcm = raw.view(np.int16).reshape(-1)
    if original_length is not None:
        pcm = pcm[:original_length]
    return int16_array_to_audio_tensor(pcm)  # 1D torch tensor

def write_wav_from_tensor(tensor: torch.Tensor, out_wav_path: str, sr: int):
    """Write a 1-D float tensor in [-1, 1] to *out_wav_path* as 16-bit mono WAV.

    Values are clipped to [-1, 1] before quantization. Creates the
    destination directory when one is present in the path. Returns the path.
    """
    x = tensor.detach().cpu().numpy()
    int16 = (np.clip(x, -1.0, 1.0) * 32767.0).astype(np.int16)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (fixes writing to a bare filename).
    out_dir = os.path.dirname(out_wav_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with wave.open(out_wav_path, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 16-bit PCM
        wf.setframerate(sr)
        wf.writeframes(int16.tobytes())
    print(f"WAV written (lossless restore): {out_wav_path} (samples={int16.size}, sr={sr})")
    return out_wav_path

# ---------------- run ----------------

if __name__ == "__main__":
    # Script entry point: run the full denoise -> PNG-pack -> roundtrip
    # pipeline using the module-level USER CONFIG constants.
    print("Running final chunked denoiser -> PNG (lossless) ...")
    wav_path, png_path, recon_path = denoise_chunked_final(
        INPUT_AUDIO,
        OUTPUT_AUDIO,
        icfg,
        chunk_seconds=CHUNK_SECONDS,
        overlap=CHUNK_OVERLAP,
        use_spectral_gate=USE_SPECTRAL_GATE,
        noise_frac=NOISE_FRAC,
        subtract_strength=SUBTRACT_STRENGTH,
    )
    print("Done.")
    print("Denoised WAV:", wav_path)
    print("Packed PNG:", png_path)
    print("Reconstructed WAV from PNG (verification):", recon_path)