91prince committed on
Commit 7eadad0 · 1 Parent(s): 698dbb3

Add SEAGAN model code, pipeline, and large checkpoint file

Files changed (6):
  1. README.md +79 -0
  2. SEGAN.py +497 -0
  3. app.py +147 -0
  4. checkpoints/seagan_final.pt +3 -0
  5. pipeline.py +378 -0
  6. requirements.txt +7 -0
README.md ADDED
@@ -0,0 +1,79 @@
+ SEAGAN Speech Enhancement & API
+ ===============================
+
+ A minimal speech-denoising project built around a SEGAN-style U-Net generator. It includes:
+ - A training script that learns from paired noisy/clean audio.
+ - An inference pipeline that denoises long clips in chunks and can pack the output audio losslessly into PNG.
+ - A FastAPI service exposing denoise and PNG pack/restore endpoints.
+
+ Repo Contents
+ -------------
+ - `SEGAN.py` – training components: config, dataset, U-Net generator, PatchGAN discriminator, training loop.
+ - `pipeline.py` – inference utilities: chunked denoiser, spectral gating cleanup, PNG pack/restore helpers.
+ - `app.py` – FastAPI app wiring the pipeline for HTTP use.
+ - `checkpoints/seagan_final.pt` – example checkpoint (replace with your own if different).
+ - `requirements.txt` – Python dependencies.
+
+ Prerequisites
+ -------------
+ - Python 3.9+ (tested with PyTorch CPU/GPU builds).
+ - For GPU inference/training, install the matching CUDA-enabled `torch`/`torchaudio`.
+ - FFmpeg is not required; `torchaudio` handles WAV I/O.
+
+ Install
+ -------
+ ```bash
+ python -m venv .venv
+ source .venv/bin/activate  # Windows PowerShell: .\.venv\Scripts\Activate.ps1
+ pip install -r requirements.txt
+ ```
+ If you need a specific CUDA wheel, install `torch`/`torchaudio` first, then run `pip install -r requirements.txt --no-deps`.
+
+ Quick Inference (CLI)
+ ---------------------
+ Use the chunked denoiser directly:
+ ```bash
+ python pipeline.py --input path/to/noisy.wav --output path/to/denoised.wav --checkpoint seagan_final.pt
+ ```
+ Notes:
+ - `--png-width` controls the image width when packing to PNG; omit `--no-pack` to also write `*_packed.png` and a reconstructed WAV for verification.
+ - The denoiser mirror-pads and overlaps chunks to reduce seams, and can optionally run a spectral subtraction cleanup.
+
+ FastAPI Service
+ ---------------
+ Environment variables:
+ - `CHECKPOINT_PATH` (default `/app/checkpoints/seagan_final.pt`)
+ - `CHECKPOINT_URL` (optional download at startup)
+ - `SAMPLE_RATE` (default `16000`)
+ - `PNG_WIDTH` (default `2048`)
+
+ Run locally:
+ ```bash
+ uvicorn app:app --host 0.0.0.0 --port 8000
+ ```
+
+ Endpoints:
+ - `POST /denoise-and-pack` – form-data key `file` with a WAV. Returns a packed PNG of the denoised audio.
+ - `POST /restore-from-png` – form-data key `file` with a packed PNG. Returns the restored WAV.
+ - `GET /health` – health check.
+
+ Model Training
+ --------------
+ `SEGAN.py` trains on paired noisy/clean WAVs. Update `Config.noisy_dir`, `Config.clean_dir`, and `Config.save_dir` to your paths, then run:
+ ```bash
+ python SEGAN.py
+ ```
+ Checkpoints are written every 5 epochs and as `seagan_final.pt` at the end. The inference pipeline expects a `G_state` entry inside the checkpoint.
+
+ PNG Packing/Restoration Utilities
+ ---------------------------------
+ `pipeline.py` exposes:
+ - `save_audio_as_png_lossless(tensor, png_path, width)` – stores int16 PCM in a lossless PNG.
+ - `load_audio_from_png_lossless(png_path, original_length)` – restores the tensor.
+ - `write_wav_from_tensor(tensor, out_wav_path, sr)` – writes a mono WAV.
+
+ Tips
+ ----
+ - Use mono input WAVs; multi-channel audio is averaged to mono.
+ - Large files are processed in chunks; adjust `chunk_seconds` and `overlap` in `denoise_chunked_final`.
+ - Make sure the checkpoint matches the model architecture in `SEGAN.py`.
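The core idea behind the lossless PNG packing can be sketched independently of the repo's helpers: quantize float audio to int16 PCM, reinterpret the bytes as a `uint8` grid of the chosen width, and invert the steps to restore. The function names `pack_int16`/`unpack_int16` below are illustrative, not the repo's actual API (which additionally writes the grid out via PIL as a PNG):

```python
import numpy as np

def pack_int16(wav: np.ndarray, width: int = 2048) -> np.ndarray:
    """Quantize float audio in [-1, 1] to int16, view the bytes as uint8,
    and reshape into a (height, width) grid, zero-padded at the end."""
    pcm = (np.clip(wav, -1.0, 1.0) * 32767.0).astype(np.int16)
    raw = np.frombuffer(pcm.tobytes(), dtype=np.uint8)
    height = -(-len(raw) // width)  # ceil division
    buf = np.zeros(height * width, dtype=np.uint8)
    buf[: len(raw)] = raw
    return buf.reshape(height, width)

def unpack_int16(img: np.ndarray, original_length: int) -> np.ndarray:
    """Invert pack_int16: flatten, reinterpret as int16, rescale to float."""
    pcm = np.frombuffer(img.reshape(-1).tobytes(), dtype=np.int16)
    return pcm[:original_length].astype(np.float32) / 32767.0

# Round trip: the only loss is the int16 quantization error.
wav = np.sin(np.linspace(0, 20 * np.pi, 16000)).astype(np.float32)
img = pack_int16(wav, width=2048)
restored = unpack_int16(img, len(wav))
```

Saving `img` with a lossless codec (PNG) preserves every byte, which is why the round trip through `save_audio_as_png_lossless`/`load_audio_from_png_lossless` can be exact up to quantization.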
SEGAN.py ADDED
@@ -0,0 +1,497 @@
+ #!/usr/bin/env python3
+ """
+ SEAGAN-style Speech Enhancement (Noise Removal) Training Script
+
+ - Generator: U-Net on log-magnitude spectrograms
+ - Discriminator: PatchGAN-style conditional (noisy + clean/enhanced)
+ - Loss: L1 (reconstruction) + adversarial (LSGAN)
+
+ Requirements:
+     pip install torch torchaudio numpy
+ """
+
+ import os
+ import glob
+ import random
+ from typing import List, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+
+ import torchaudio
+
+ # ==========================
+ # CONFIG
+ # ==========================
+
+ class Config:
+     # Paths (CHANGE THESE TO YOUR FOLDERS)
+     noisy_dir = r"E:\Minor-Project-For-Amity-Patna\Models\Audio Data\Noisy Data"      # noisy wavs
+     clean_dir = r"E:\Minor-Project-For-Amity-Patna\Models\Audio Data\Noiseless Data"  # clean wavs
+     save_dir = r"E:\Minor-Project-For-Amity-Patna\Model SEGAN\checkpoints_seagan"
+
+     # Audio
+     sample_rate = 16000
+     segment_seconds = 1.0  # train on 1-second chunks
+     mono = True
+
+     # STFT / Spectrogram
+     n_fft = 512
+     hop_length = 128
+     win_length = 512
+
+     # Training
+     batch_size = 8
+     num_workers = 2
+     num_epochs = 50
+     lr_g = 2e-4
+     lr_d = 2e-4
+     beta1 = 0.5
+     beta2 = 0.999
+
+     lambda_l1 = 100.0  # weight for L1 loss vs GAN loss (like pix2pix)
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ cfg = Config()
+
+
+ # ==========================
+ # DATASET
+ # ==========================
+
+ def list_wav_pairs(noisy_dir: str, clean_dir: str) -> List[Tuple[str, str]]:
+     noisy_files = sorted(glob.glob(os.path.join(noisy_dir, "*.wav")))
+     pairs = []
+     for nf in noisy_files:
+         name = os.path.basename(nf)
+         cf = os.path.join(clean_dir, name)
+         if os.path.exists(cf):
+             pairs.append((nf, cf))
+     return pairs
+
+
+ class SEAGANDataset(Dataset):
+     def __init__(
+         self,
+         noisy_dir: str,
+         clean_dir: str,
+         sample_rate: int = 16000,
+         segment_seconds: float = 1.0,
+     ):
+         self.sample_rate = sample_rate
+         self.segment_samples = int(segment_seconds * sample_rate)
+
+         self.pairs = list_wav_pairs(noisy_dir, clean_dir)
+         if len(self.pairs) == 0:
+             raise RuntimeError("No paired .wav files found! Check your folders & names.")
+
+         self.resampler_cache = {}
+
+     def __len__(self):
+         return len(self.pairs)
+
+     def _get_resampler(self, orig_sr: int):
+         if orig_sr == self.sample_rate:
+             return None
+         if orig_sr not in self.resampler_cache:
+             self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(
+                 orig_freq=orig_sr, new_freq=self.sample_rate
+             )
+         return self.resampler_cache[orig_sr]
+
+     def _load_audio(self, path: str) -> torch.Tensor:
+         wav, sr = torchaudio.load(path)  # shape: (channels, samples)
+         if wav.shape[0] > 1:
+             wav = wav.mean(dim=0, keepdim=True)  # mono
+         resampler = self._get_resampler(sr)
+         if resampler is not None:
+             wav = resampler(wav)
+         return wav  # (1, samples)
+
+     def _aligned_random_crop(self, noisy: torch.Tensor, clean: torch.Tensor):
+         """
+         Crop noisy and clean with the same start index for alignment.
+         noisy, clean: (1, T)
+         """
+         T = min(noisy.shape[1], clean.shape[1])
+         noisy = noisy[:, :T]
+         clean = clean[:, :T]
+
+         if T <= self.segment_samples:
+             pad = self.segment_samples - T
+             noisy = torch.nn.functional.pad(noisy, (0, pad))
+             clean = torch.nn.functional.pad(clean, (0, pad))
+             return noisy, clean
+         else:
+             start = random.randint(0, T - self.segment_samples)
+             end = start + self.segment_samples
+             return noisy[:, start:end], clean[:, start:end]
+
+     def __getitem__(self, idx: int):
+         noisy_path, clean_path = self.pairs[idx]
+
+         noisy = self._load_audio(noisy_path)
+         clean = self._load_audio(clean_path)
+
+         noisy, clean = self._aligned_random_crop(noisy, clean)
+
+         return noisy, clean
+
+
+ # ==========================
+ # SPECTROGRAM HELPERS
+ # ==========================
+
+ class STFTMagTransform(nn.Module):
+     """
+     Convert waveform -> log-magnitude spectrogram
+     """
+
+     def __init__(self, n_fft, hop_length, win_length):
+         super().__init__()
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+
+         # register window so it moves with .to(device)
+         self.register_buffer("window", torch.hann_window(win_length))
+
+     def forward(self, wav: torch.Tensor) -> torch.Tensor:
+         """
+         wav: (B, 1, T)
+         return: (B, 1, F, T_spec)
+         """
+         B, C, T = wav.shape
+
+         wav = wav.view(B * C, T)
+         spec = torch.stft(
+             wav,
+             n_fft=self.n_fft,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=self.window,
+             return_complex=True,
+         )
+         mag = torch.abs(spec)       # (B*C, F, T_spec)
+         log_mag = torch.log1p(mag)  # log(1 + mag)
+         log_mag = log_mag.view(B, C, log_mag.shape[1], log_mag.shape[2])
+         return log_mag
+
+
+ # ==========================
+ # SIZE MATCH HELPER
+ # ==========================
+
+ def match_size(a: torch.Tensor, b: torch.Tensor):
+     """
+     Crop a and b to have the same (H, W). Keeps the top-left region.
+     a, b: (..., H, W)
+     returns: (a_crop, b_crop)
+     """
+     Ha, Wa = a.shape[-2], a.shape[-1]
+     Hb, Wb = b.shape[-2], b.shape[-1]
+     H = min(Ha, Hb)
+     W = min(Wa, Wb)
+     a_c = a[..., :H, :W]
+     b_c = b[..., :H, :W]
+     return a_c, b_c
+
+
+ # ==========================
+ # GENERATOR (U-NET)
+ # ==========================
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_ch, out_ch, down=True, use_bn=True):
+         super().__init__()
+         if down:
+             layers = [
+                 nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ]
+         else:
+             layers = [
+                 nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
+                 nn.ReLU(inplace=True),
+             ]
+
+         if use_bn:
+             layers.insert(1, nn.BatchNorm2d(out_ch))
+
+         self.block = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class UNetGenerator(nn.Module):
+     """
+     U-Net operating on (B, 1, F, T) log-magnitude spectrograms
+     """
+     def __init__(self, in_ch=1, out_ch=1, base_ch=64):
+         super().__init__()
+
+         # Encoder
+         self.down1 = ConvBlock(in_ch, base_ch, down=True, use_bn=False)  # (64)
+         self.down2 = ConvBlock(base_ch, base_ch * 2)
+         self.down3 = ConvBlock(base_ch * 2, base_ch * 4)
+         self.down4 = ConvBlock(base_ch * 4, base_ch * 8)
+         self.down5 = ConvBlock(base_ch * 8, base_ch * 8)
+
+         # Bottleneck
+         self.bottleneck = nn.Sequential(
+             nn.Conv2d(base_ch * 8, base_ch * 8, kernel_size=4, stride=2, padding=1),
+             nn.ReLU(inplace=True),
+         )
+
+         # Decoder
+         self.up1 = ConvBlock(base_ch * 8, base_ch * 8, down=False)
+         self.up2 = ConvBlock(base_ch * 8 * 2, base_ch * 8, down=False)
+         self.up3 = ConvBlock(base_ch * 8 * 2, base_ch * 4, down=False)
+         self.up4 = ConvBlock(base_ch * 4 * 2, base_ch * 2, down=False)
+         self.up5 = ConvBlock(base_ch * 2 * 2, base_ch, down=False)
+
+         self.final = nn.ConvTranspose2d(
+             base_ch * 2, out_ch, kernel_size=4, stride=2, padding=1
+         )
+         # Output non-negative log-magnitude
+         self.out_act = nn.ReLU()
+
+     def _crop_to(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
+         """
+         Center-crop src to have the same H, W as tgt.
+         src: (B, C, Hs, Ws)
+         tgt: (B, C, Ht, Wt)  (only Ht, Wt are used)
+         """
+         _, _, Hs, Ws = src.shape
+         _, _, Ht, Wt = tgt.shape
+
+         if Hs == Ht and Ws == Wt:
+             return src
+
+         start_h = max((Hs - Ht) // 2, 0)
+         start_w = max((Ws - Wt) // 2, 0)
+         end_h = start_h + Ht
+         end_w = start_w + Wt
+
+         return src[:, :, start_h:end_h, start_w:end_w]
+
+     def forward(self, x):
+         # encoder
+         d1 = self.down1(x)   # B,64
+         d2 = self.down2(d1)  # B,128
+         d3 = self.down3(d2)  # B,256
+         d4 = self.down4(d3)  # B,512
+         d5 = self.down5(d4)  # B,512
+
+         bott = self.bottleneck(d5)
+
+         # decoder with crops + skips
+         u1 = self.up1(bott)
+         d5_c = self._crop_to(d5, u1)
+         u1 = torch.cat([u1, d5_c], dim=1)
+
+         u2 = self.up2(u1)
+         d4_c = self._crop_to(d4, u2)
+         u2 = torch.cat([u2, d4_c], dim=1)
+
+         u3 = self.up3(u2)
+         d3_c = self._crop_to(d3, u3)
+         u3 = torch.cat([u3, d3_c], dim=1)
+
+         u4 = self.up4(u3)
+         d2_c = self._crop_to(d2, u4)
+         u4 = torch.cat([u4, d2_c], dim=1)
+
+         u5 = self.up5(u4)
+         d1_c = self._crop_to(d1, u5)
+         u5 = torch.cat([u5, d1_c], dim=1)
+
+         out = self.final(u5)
+         out = self.out_act(out)  # non-negative log-magnitude
+         return out
+
+
+ # ==========================
+ # DISCRIMINATOR (PatchGAN)
+ # ==========================
+
+ class PatchDiscriminator(nn.Module):
+     """
+     Conditional discriminator: input = concat(noisy_spec, clean_or_fake_spec)
+     """
+     def __init__(self, in_ch=2, base_ch=64):
+         super().__init__()
+         # no batchnorm in the first layer
+         self.model = nn.Sequential(
+             nn.Conv2d(in_ch, base_ch, kernel_size=4, stride=2, padding=1),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(base_ch, base_ch * 2, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(base_ch * 2),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(base_ch * 2, base_ch * 4, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(base_ch * 4),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(base_ch * 4, base_ch * 8, kernel_size=4, stride=1, padding=1),
+             nn.BatchNorm2d(base_ch * 8),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(base_ch * 8, 1, kernel_size=4, stride=1, padding=1),
+             # no final activation -> LSGAN
+         )
+
+     def forward(self, x):
+         return self.model(x)  # (B, 1, H', W')
+
+
+ # ==========================
+ # TRAINING
+ # ==========================
+
+ def save_checkpoint(epoch, G, D, opt_g, opt_d, path):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     torch.save(
+         {
+             "epoch": epoch,
+             "G_state": G.state_dict(),
+             "D_state": D.state_dict(),
+             "opt_g_state": opt_g.state_dict(),
+             "opt_d_state": opt_d.state_dict(),
+         },
+         path,
+     )
+     print(f"Saved checkpoint: {path}")
+
+
+ def train():
+     device = cfg.device
+     print(f"Using device: {device}")
+
+     dataset = SEAGANDataset(
+         cfg.noisy_dir, cfg.clean_dir, cfg.sample_rate, cfg.segment_seconds
+     )
+     loader = DataLoader(
+         dataset,
+         batch_size=cfg.batch_size,
+         shuffle=True,
+         num_workers=cfg.num_workers,
+         drop_last=True,
+     )
+
+     stft_transform = STFTMagTransform(
+         cfg.n_fft, cfg.hop_length, cfg.win_length
+     ).to(device)
+
+     G = UNetGenerator(in_ch=1, out_ch=1).to(device)
+     D = PatchDiscriminator(in_ch=2).to(device)
+
+     # LSGAN loss
+     criterion_gan = nn.MSELoss()
+     criterion_l1 = nn.L1Loss()
+
+     opt_g = optim.Adam(G.parameters(), lr=cfg.lr_g, betas=(cfg.beta1, cfg.beta2))
+     opt_d = optim.Adam(D.parameters(), lr=cfg.lr_d, betas=(cfg.beta1, cfg.beta2))
+
+     for epoch in range(1, cfg.num_epochs + 1):
+         G.train()
+         D.train()
+
+         running_g_loss = 0.0
+         running_d_loss = 0.0
+
+         for i, (noisy_wav, clean_wav) in enumerate(loader):
+             noisy_wav = noisy_wav.to(device)  # (B,1,T)
+             clean_wav = clean_wav.to(device)  # (B,1,T)
+
+             # -------------------------
+             # Waveform -> Spectrogram
+             # -------------------------
+             noisy_spec = stft_transform(noisy_wav)  # (B,1,F,T_spec)
+             clean_spec = stft_transform(clean_wav)  # (B,1,F,T_spec)
+
+             # Ensure the same size for the real pair
+             noisy_spec, clean_spec = match_size(noisy_spec, clean_spec)
+
+             # =========================
+             # Train Discriminator
+             # =========================
+             opt_d.zero_grad()
+
+             # Real pair: (noisy, clean)
+             real_input = torch.cat([noisy_spec, clean_spec], dim=1)
+             pred_real = D(real_input)
+             target_real = torch.ones_like(pred_real)
+             loss_d_real = criterion_gan(pred_real, target_real)
+
+             # Fake pair: (noisy, enhanced)
+             with torch.no_grad():
+                 fake_spec = G(noisy_spec)
+             # match noisy and fake sizes
+             noisy_for_fake_d, fake_spec_d = match_size(noisy_spec, fake_spec)
+             fake_input = torch.cat([noisy_for_fake_d, fake_spec_d], dim=1)
+             pred_fake = D(fake_input)
+             target_fake = torch.zeros_like(pred_fake)
+             loss_d_fake = criterion_gan(pred_fake, target_fake)
+
+             loss_d = 0.5 * (loss_d_real + loss_d_fake)
+             loss_d.backward()
+             opt_d.step()
+
+             # =========================
+             # Train Generator
+             # =========================
+             opt_g.zero_grad()
+
+             fake_spec = G(noisy_spec)
+
+             # GAN loss (want D(noisy, fake) = 1)
+             noisy_for_fake_g, fake_spec_g = match_size(noisy_spec, fake_spec)
+             fake_input_g = torch.cat([noisy_for_fake_g, fake_spec_g], dim=1)
+             pred_fake_for_g = D(fake_input_g)
+             target_real_for_g = torch.ones_like(pred_fake_for_g)
+             loss_g_gan = criterion_gan(pred_fake_for_g, target_real_for_g)
+
+             # L1 reconstruction loss (match fake & clean sizes)
+             fake_l1, clean_l1 = match_size(fake_spec, clean_spec)
+             loss_g_l1 = criterion_l1(fake_l1, clean_l1) * cfg.lambda_l1
+
+             loss_g = loss_g_gan + loss_g_l1
+             loss_g.backward()
+             opt_g.step()
+
+             running_d_loss += loss_d.item()
+             running_g_loss += loss_g.item()
+
+             if (i + 1) % 20 == 0:
+                 print(
+                     f"Epoch [{epoch}/{cfg.num_epochs}] "
+                     f"Step [{i+1}/{len(loader)}] "
+                     f"D Loss: {loss_d.item():.4f} "
+                     f"G Loss: {loss_g.item():.4f} "
+                     f"(GAN: {loss_g_gan.item():.4f}, L1: {loss_g_l1.item():.4f})"
+                 )
+
+         avg_d = running_d_loss / len(loader)
+         avg_g = running_g_loss / len(loader)
+         print(
+             f"==> Epoch {epoch} finished | "
+             f"Avg D Loss: {avg_d:.4f} | Avg G Loss: {avg_g:.4f}"
+         )
+
+         # save a checkpoint every few epochs
+         if epoch % 5 == 0:
+             ckpt_path = os.path.join(cfg.save_dir, f"seagan_epoch_{epoch}.pt")
+             save_checkpoint(epoch, G, D, opt_g, opt_d, ckpt_path)
+
+     # final save
+     ckpt_path = os.path.join(cfg.save_dir, "seagan_final.pt")
+     save_checkpoint(cfg.num_epochs, G, D, opt_g, opt_d, ckpt_path)
+
+
+ if __name__ == "__main__":
+     train()
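The `STFTMagTransform` above compresses magnitudes with `log1p`, which is why the generator's output head is a ReLU (log1p of a non-negative magnitude is always non-negative). At inference time the compression must be undone with `expm1` before combining the magnitude with the noisy phase; a minimal numpy sketch of that exact inverse mapping (not the repo's inference code):

```python
import numpy as np

np.random.seed(0)
# Fake |STFT| magnitudes, shaped (freq_bins, frames); always >= 0.
mag = np.abs(np.random.randn(257, 126)).astype(np.float32)

# Forward: what STFTMagTransform stores, log(1 + mag) >= 0.
log_mag = np.log1p(mag)

# Inverse used before istft: expm1 undoes log1p exactly (up to float error).
recovered = np.expm1(log_mag)
```

The same pair (`torch.log1p` / `torch.expm1`) applies tensor-wise in PyTorch.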
app.py ADDED
@@ -0,0 +1,147 @@
+ # app.py
+ import os
+ import uvicorn
+ import torch
+ import tempfile
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.responses import FileResponse
+ from starlette.middleware.cors import CORSMiddleware
+
+ # --- Import your denoiser functions (adjust the import if SEGAN.py is in a subfolder) ---
+ # from SEGAN import Config, STFTMagTransform, UNetGenerator
+ # from your_denoiser_module import denoise_chunked_final, save_audio_as_png_lossless, load_audio_from_png_lossless, write_wav_from_tensor
+ # For clarity, this file assumes denoise_chunked_final and the packing functions live in the `pipeline` module.
+ from pipeline import InferConfig, denoise_chunked_final, save_audio_as_png_lossless, load_audio_from_png_lossless, write_wav_from_tensor
+
+ # --- Config from env ---
+ CHECKPOINT = os.environ.get("CHECKPOINT_PATH", "/app/checkpoints/seagan_final.pt")
+ CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL")  # optional: download at startup
+ SAMPLE_RATE = int(os.environ.get("SAMPLE_RATE", "16000"))
+ PNG_WIDTH = int(os.environ.get("PNG_WIDTH", "2048"))
+
+ # Create directories
+ os.makedirs("/app/data", exist_ok=True)
+ os.makedirs("/app/checkpoints", exist_ok=True)
+ os.makedirs("/tmp", exist_ok=True)
+
+ app = FastAPI(title="SEGAN Denoise + PNG packer API")
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Download the checkpoint if provided via URL and not present
+ def ensure_checkpoint():
+     if os.path.isfile(CHECKPOINT):
+         print("Checkpoint exists:", CHECKPOINT)
+         return CHECKPOINT
+     if CHECKPOINT_URL:
+         import requests
+         print("Downloading checkpoint from URL...")
+         r = requests.get(CHECKPOINT_URL, stream=True, timeout=60)
+         if r.status_code != 200:
+             raise RuntimeError("Failed to download checkpoint; status=" + str(r.status_code))
+         outp = CHECKPOINT
+         os.makedirs(os.path.dirname(outp), exist_ok=True)
+         with open(outp, "wb") as f:
+             for chunk in r.iter_content(chunk_size=8192):
+                 f.write(chunk)
+         print("Downloaded checkpoint to", outp)
+         return outp
+     raise FileNotFoundError("No checkpoint found; set the CHECKPOINT_PATH or CHECKPOINT_URL environment variable.")
+
+ # Initialize the model config object (the pipeline expects an InferConfig wrapping your SEGAN config)
+ icfg = InferConfig()
+ icfg.ckpt_path = CHECKPOINT
+
+ @app.on_event("startup")
+ def startup_event():
+     # ensure the checkpoint is present
+     try:
+         ensure_checkpoint()
+     except Exception as e:
+         print("Warning: checkpoint not found at startup:", e)
+     print("Startup complete.")
+
+ @app.post("/denoise-and-pack")
+ async def denoise_and_pack(file: UploadFile = File(...)):
+     """
+     Accepts a WAV file upload. Returns a packed PNG containing lossless int16 PCM of the denoised audio.
+     Form-data key: 'file'
+     """
+     # Accept only audio/wav or octet-stream
+     if file.content_type not in ("audio/wav", "audio/x-wav", "application/octet-stream"):
+         # still accept other clients, but warn
+         print("Warning: uploaded content_type:", file.content_type)
+
+     # Save the upload to a temp WAV file
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
+         tmp_in.write(await file.read())
+         tmp_in.flush()
+         tmp_in_path = tmp_in.name
+
+     # Prepare output paths
+     base = os.path.splitext(os.path.basename(tmp_in_path))[0]
+     out_wav_path = f"/app/data/{base}_denoised.wav"
+     out_png_path = f"/app/data/{base}_packed.png"
+     # Run the denoiser & packer (this function should save the WAV and pack the PNG; returns paths)
+     try:
+         print("Running denoiser for:", tmp_in_path)
+         # The denoiser can be heavy; it runs on CPU when no GPU is available
+         out = denoise_chunked_final(tmp_in_path, out_wav_path, icfg,
+                                     chunk_seconds=50.0, overlap=0.5,
+                                     use_spectral_gate=True, noise_frac=0.1, subtract_strength=1.0)
+         # out may be (wav_path, png_path, recon_wav) depending on your pipeline
+     except Exception as e:
+         print("Denoiser error:", e)
+         raise HTTPException(status_code=500, detail="Denoiser failed: " + str(e))
+
+     # If the denoiser already wrote the packed PNG, send that; else pack here
+     if os.path.exists(out_png_path):
+         png_to_send = out_png_path
+     else:
+         # Load the denoised tensor (adapt this to how your denoiser returns data).
+         # pipeline.save_audio_as_png_lossless takes a tensor; if you only have a file, use torchaudio.load
+         import torchaudio
+         wav, sr = torchaudio.load(out_wav_path)
+         if wav.size(0) > 1:
+             wav = wav.mean(dim=0, keepdim=True)
+         wav1d = wav.squeeze(0)
+         save_audio_as_png_lossless(wav1d, out_png_path, width=PNG_WIDTH)
+         png_to_send = out_png_path
+
+     return FileResponse(png_to_send, media_type="image/png", filename=os.path.basename(png_to_send))
+
+ @app.post("/restore-from-png")
+ async def restore_from_png(file: UploadFile = File(...)):
+     """
+     Accepts a packed PNG upload and returns the restored WAV (mono int16) at SAMPLE_RATE.
+     """
+     if file.content_type not in ("image/png", "application/octet-stream"):
+         print("Warning: uploaded content_type:", file.content_type)
+     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_png:
+         tmp_png.write(await file.read())
+         tmp_png.flush()
+         tmp_png_path = tmp_png.name
+
+     try:
+         restored_tensor = load_audio_from_png_lossless(tmp_png_path, original_length=None)
+         out_wav = f"/app/data/restored_{os.path.basename(tmp_png_path)}.wav"
+         write_wav_from_tensor(restored_tensor, out_wav, SAMPLE_RATE)
+     except Exception as e:
+         print("Restore error:", e)
+         raise HTTPException(status_code=500, detail="Restore failed: " + str(e))
+
+     return FileResponse(out_wav, media_type="audio/wav", filename=os.path.basename(out_wav))
+
+ # Optional simple healthcheck
+ @app.get("/health")
+ def health():
+     return {"status": "ok"}
+
+ # Run when invoked directly (development)
+ if __name__ == "__main__":
+     uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
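The `denoise_chunked_final` call above passes `overlap=0.5` so chunk boundaries are blended rather than concatenated. The blending idea can be sketched on its own: weight each chunk with a window that ramps through the overlap region and normalize by the accumulated weights (overlap-add). The function `overlap_add` below is a hypothetical numpy illustration with a linear crossfade, not the repo's implementation:

```python
import numpy as np

def overlap_add(chunks, chunk_len, hop):
    """Blend equal-length chunks spaced `hop` samples apart, crossfading
    linearly over the (chunk_len - hop) overlap region."""
    overlap = chunk_len - hop
    total = hop * (len(chunks) - 1) + chunk_len
    out = np.zeros(total, dtype=np.float32)
    weight = np.zeros(total, dtype=np.float32)
    win = np.ones(chunk_len, dtype=np.float32)
    if overlap > 0:
        ramp = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
        win[:overlap] = ramp          # fade in
        win[-overlap:] = ramp[::-1]   # fade out
    for i, c in enumerate(chunks):
        start = i * hop
        out[start:start + chunk_len] += c * win
        weight[start:start + chunk_len] += win
    # normalize so regions covered by one chunk are untouched
    return out / np.maximum(weight, 1e-8)

# Split a signal into 25%-overlapping chunks and rebuild it seamlessly.
sig = np.sin(np.linspace(0, 8 * np.pi, 4000)).astype(np.float32)
chunk_len, hop = 1000, 750
chunks = [sig[i:i + chunk_len] for i in range(0, len(sig) - chunk_len + 1, hop)]
recon = overlap_add(chunks, chunk_len, hop)
```

With denoised chunks instead of raw slices, the crossfade hides per-chunk discontinuities at the seams.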
checkpoints/seagan_final.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9641f1516a0123767e1684f85b09f9fc919949f7104983619bdb5088e815dae8
+ size 384194538
pipeline.py ADDED
@@ -0,0 +1,378 @@
+ #!/usr/bin/env python3
+ """
+ pipeline.py
+
+ Contains:
+ - InferConfig (wraps your SEGAN.Config)
+ - denoise_chunked_final(...) -> denoised WAV path, packed PNG path, reconstructed WAV path
+ - save_audio_as_png_lossless / load_audio_from_png_lossless / write_wav_from_tensor
+ - helper utilities used by the denoiser (robust_save, mirror-pad, spectral gating)
+
+ Usage: import the functions in your FastAPI `app.py`, or run this file directly for a local test.
+
+ Note: this module expects your SEGAN.py (containing Config, STFTMagTransform, UNetGenerator)
+ to be available in the same directory or on PYTHONPATH. Adjust the imports if needed.
+ """
+
+ import os
+ import math
+ import wave
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ import numpy as np
+ from PIL import Image
+
+ # Try to import SEGAN components; SEGAN.py must be in the same folder or package
+ try:
+     from SEGAN import Config, STFTMagTransform, UNetGenerator
+ except Exception as e:
+     # If the import fails, raise a clear error when the functions are used; keep the module
+     # importable for tools that just want the pack/unpack functions.
+     Config = None
+     STFTMagTransform = None
+     UNetGenerator = None
+     _import_error = e
+
+
+ # ----------------- Configuration (defaults) -----------------
+ DEFAULT_CHECKPOINT = os.environ.get("CHECKPOINT_PATH", "./checkpoints/seagan_final.pt")
+
+ # ----------------- Infer config wrapper ---------------------
+ class InferConfig:
+     """Simple wrapper for your SEGAN.Config. If SEGAN.Config is available we use it; else provide defaults.
+     Attributes expected by the pipeline: ckpt_path, device, n_fft, hop_length, win_length, sample_rate
+     """
+     def __init__(self,
+                  ckpt_path: str = DEFAULT_CHECKPOINT,
+                  device: str = "cuda" if torch.cuda.is_available() else "cpu",
+                  n_fft: int = 1024,
+                  hop_length: int = 256,
+                  win_length: int = 1024,
+                  sample_rate: int = 16000):
+         # If the real SEGAN.Config exists, instantiate it and override ckpt_path + device
+         if Config is not None:
+             try:
+                 cfg = Config()
+                 cfg.ckpt_path = ckpt_path
+                 cfg.device = device
+                 # copy class-level and instance fields from Config (class attrs are not
+                 # in cfg.__dict__, so merge both namespaces)
+                 for k, v in {**vars(type(cfg)), **vars(cfg)}.items():
+                     if not k.startswith("__") and not callable(v):
+                         setattr(self, k, v)
+                 return
+             except Exception:
+                 # fall through to the default fields
+                 pass
+         # fallback defaults
+         self.ckpt_path = ckpt_path
+         self.device = device
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.sample_rate = sample_rate
+
+
+ # ---------------- utilities -------------------
+
+ def load_mono_resampled(path: str, target_sr: int):
+     wav, sr = torchaudio.load(path)
+     if wav.size(0) > 1:
+         wav = wav.mean(dim=0, keepdim=True)
+     if sr != target_sr:
+         wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
+         sr = target_sr
+     return wav.squeeze(0)  # (T,)
+
+
+ def robust_save(path: str, wav_tensor: torch.Tensor, sr: int):
+     x = wav_tensor.detach().cpu()
+     if x.dim() == 1:
+         x = x.unsqueeze(0)
+     while x.dim() > 2 and x.size(0) == 1:
+         x = x.squeeze(0)
+     if x.dim() > 2:
+         x = torch.squeeze(x)
+     if x.dim() == 1:
+         x = x.unsqueeze(0)
+     x = x.float()
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     torchaudio.save(path, x, sr)
+     print(f"Saved WAV: {path} (shape={tuple(x.shape)})")
+
+
+ def pad_or_crop_freq(mag: torch.Tensor, target_F: int):
+     F_mag = mag.shape[1]
+     if F_mag == target_F:
+         return mag
+     if F_mag < target_F:
+         pad = target_F - F_mag
+         return F.pad(mag, (0, 0, 0, pad))
+     else:
+         return mag[:, :target_F, :]
+
+
+ def mirror_pad_last_chunk(chunk: torch.Tensor, target_len: int):
+     L = chunk.shape[-1]
+     if L >= target_len:
+         return chunk[:, :, :target_len]
+     need = target_len - L
+     frag = chunk[..., -min(L, need):].flip(-1)
+     out = torch.cat([chunk, frag], dim=-1)
+     if out.shape[-1] < target_len:
+         out = F.pad(out, (0, target_len - out.shape[-1]))
+     return out[:, :, :target_len]
+
+
+ # ---------------- spectral gating (final cleanup) ----------------
+
+ def spectral_subtract_and_reconstruct(waveform: torch.Tensor, stft_mod, cfg: InferConfig,
+                                       noise_frac=0.1, subtract_strength=1.0, device='cpu'):
+     if waveform.dim() == 1:
+         wav = waveform.unsqueeze(0)  # (1, T)
+     else:
+         wav = waveform
+     wav = wav.to(device)
+
+     n_fft = cfg.n_fft
+     hop = cfg.hop_length
+     win = stft_mod.window.to(device) if stft_mod is not None else torch.hann_window(cfg.win_length).to(device)
+
+     spec = torch.stft(wav, n_fft=n_fft, hop_length=hop, win_length=cfg.win_length, window=win, return_complex=True)
+     mag = torch.abs(spec)      # (1, F, T)
+     phase = torch.angle(spec)  # (1, F, T)
+
+     frame_energy = mag.pow(2).sum(dim=1).squeeze(0)  # (T,)
+     n_frames = frame_energy.shape[-1]
+     if n_frames <= 0:
+         return wav.squeeze(0).cpu()
+
+     # estimate the noise floor from the lowest-energy frames
+     k = max(1, int(n_frames * noise_frac))
+     idxs = torch.argsort(frame_energy)[:k]
+     noise_floor = mag[:, :, idxs].median(dim=-1).values  # (1, F)
+     noise_floor_exp = noise_floor.unsqueeze(-1).repeat(1, 1, mag.shape[-1])
+
+     alpha = subtract_strength
+     mag_sub = mag - alpha * noise_floor_exp
+     mag_sub = torch.clamp(mag_sub, min=0.0)
+
+     real = mag_sub * torch.cos(phase)
+     imag = mag_sub * torch.sin(phase)
+     complex_sub = torch.complex(real, imag)
+
+     recon = torch.istft(complex_sub, n_fft=n_fft, hop_length=hop, win_length=cfg.win_length, window=win, length=wav.shape[-1])
+     return recon.squeeze(0).cpu()
+
+
+ # ---------------- core chunked denoiser (improved) ----------------
+
+ def denoise_chunked_final(input_path: str, output_path: str, cfg: InferConfig,
+                           chunk_seconds=3.0, overlap=0.5,
+                           use_spectral_gate=True, noise_frac=0.1, subtract_strength=1.0,
170
+ pack_png=True, png_width=2048):
171
+ """
172
+ Runs the chunked denoiser using the SEGAN generator.
173
+ Returns tuple: (out_wav_path, packed_png_path_or_None, recon_wav_path_or_None)
174
+ """
175
+ device = cfg.device
176
+ print("Device:", device)
177
+
178
+ # Check SEGAN availability
179
+ if UNetGenerator is None or STFTMagTransform is None or Config is None:
180
+ raise RuntimeError(f"SEGAN components not available. Original import error: {_import_error}")
181
+
182
+ # load model + stft
183
+ print("Loading checkpoint:", cfg.ckpt_path)
184
+ ckpt = torch.load(cfg.ckpt_path, map_location=device)
185
+ G = UNetGenerator(in_ch=1, out_ch=1).to(device)
186
+ G.load_state_dict(ckpt["G_state"])
187
+ G.eval()
188
+
189
+ stft = STFTMagTransform(cfg.n_fft, cfg.hop_length, cfg.win_length).to(device)
190
+ window = stft.window.to(device)
191
+
192
+ # load audio
193
+ wav = load_mono_resampled(input_path, cfg.sample_rate) # (T,)
194
+ T = wav.shape[0]
195
+ sr = cfg.sample_rate
196
+ print(f"Input: {T} samples ({T/sr:.2f} s) SR={sr}")
197
+
198
+ chunk_samples = max(1, int(chunk_seconds * sr))
199
+ hop = max(1, int(chunk_samples * (1.0 - overlap)))
200
+ print(f"Chunk {chunk_samples} samples, hop {hop} samples")
201
+
202
+ out_len = T + chunk_samples
203
+ out_buffer = torch.zeros(out_len, dtype=torch.float32)
204
+ weight_buffer = torch.zeros(out_len, dtype=torch.float32)
205
+
206
+ synth_win = torch.hann_window(chunk_samples, periodic=True, dtype=torch.float32)
207
+
208
+ idx = 0
209
+ while idx < T:
210
+ start = idx
211
+ end = min(idx + chunk_samples, T)
212
+ chunk = wav[start:end].unsqueeze(0).unsqueeze(0).to(device) # (1,1,L)
213
+ orig_len = chunk.shape[-1]
214
+ if orig_len < chunk_samples:
215
+ chunk = mirror_pad_last_chunk(chunk, chunk_samples).to(device)
216
+
217
+ with torch.no_grad():
218
+ spec = stft(chunk) # (1,1,F_spec,Frames)
219
+ fake = G(spec) # (1,1,F_fake,Frames)
220
+ mag = torch.expm1(fake.clamp_min(0.0)).squeeze(1) # (1,F_fake,Frames)
221
+
222
+ chunk_1d = chunk.view(1, -1)
223
+ complex_noisy = torch.stft(chunk_1d, n_fft=cfg.n_fft, hop_length=cfg.hop_length,
224
+ win_length=cfg.win_length, window=window, return_complex=True)
225
+ phase = torch.angle(complex_noisy) # (1,F_phase,Frames_phase)
226
+
227
+ n_frames_mag = mag.shape[-1]
228
+ n_frames_phase = phase.shape[-1]
229
+ min_frames = min(n_frames_mag, n_frames_phase)
230
+ mag = mag[..., :min_frames]
231
+ phase = phase[..., :min_frames]
232
+
233
+ expected_F = cfg.n_fft // 2 + 1
234
+ mag = pad_or_crop_freq(mag, expected_F)
235
+
236
+ real = mag * torch.cos(phase)
237
+ imag = mag * torch.sin(phase)
238
+ complex_spec = torch.complex(real, imag).squeeze(0) # (F, frames)
239
+
240
+ wav_rec = torch.istft(complex_spec.unsqueeze(0).to(device),
241
+ n_fft=cfg.n_fft, hop_length=cfg.hop_length,
242
+ win_length=cfg.win_length, window=window,
243
+ length=chunk_samples).squeeze(0).cpu()
244
+
245
+ if wav_rec.shape[-1] < chunk_samples:
246
+ wav_rec = F.pad(wav_rec, (0, chunk_samples - wav_rec.shape[-1]))
247
+ elif wav_rec.shape[-1] > chunk_samples:
248
+ wav_rec = wav_rec[:chunk_samples]
249
+
250
+ win = synth_win.clone().cpu()
251
+ wav_rec_win = wav_rec * win
252
+
253
+ write_start = start
254
+ write_end = start + chunk_samples
255
+ out_buffer[write_start:write_end] += wav_rec_win
256
+ weight_buffer[write_start:write_end] += win
257
+
258
+ idx += hop
259
+
260
+ nonzero = weight_buffer > 1e-8
261
+ out_buffer[nonzero] = out_buffer[nonzero] / weight_buffer[nonzero]
262
+ denoised = out_buffer[:T].contiguous()
263
+
264
+ if use_spectral_gate:
265
+ print("Applying final spectral gating...")
266
+ denoised = spectral_subtract_and_reconstruct(denoised.unsqueeze(0), stft, cfg,
267
+ noise_frac=noise_frac, subtract_strength=subtract_strength,
268
+ device=cfg.device)
269
+
270
+ denoised = torch.clamp(denoised, -0.999, 0.999)
271
+
272
+ # save denoised wav
273
+ robust_save(output_path, denoised, sr)
274
+
275
+ packed_png = None
276
+ recon_wav = None
277
+ if pack_png:
278
+ packed_png = os.path.splitext(output_path)[0] + "_packed.png"
279
+ save_audio_as_png_lossless(denoised, packed_png, width=png_width)
280
+ print("Packed denoised audio into PNG:", packed_png)
281
+ # optional: reconstruct to verify
282
+ recon_wav = os.path.splitext(output_path)[0] + "_reconstructed_from_png.wav"
283
+ restored = load_audio_from_png_lossless(packed_png, original_length=denoised.shape[-1])
284
+ write_wav_from_tensor(restored, recon_wav, sr)
285
+ print("Reconstructed WAV from PNG:", recon_wav)
286
+
287
+ return output_path, packed_png, recon_wav
288
+
289
+
290
+ # === Lossless audio <-> PNG packing (bit-perfect) ===
291
+
292
+ def audio_tensor_to_int16_array(wav_tensor: torch.Tensor):
293
+ if isinstance(wav_tensor, torch.Tensor):
294
+ x = wav_tensor.detach().cpu().numpy()
295
+ else:
296
+ x = np.asarray(wav_tensor)
297
+ if x.ndim == 2 and x.shape[0] == 1:
298
+ x = x[0]
299
+ x = np.clip(x, -1.0, 1.0)
300
+ int16 = (x * 32767.0).astype(np.int16)
301
+ return int16
302
+
303
+
304
+ def int16_array_to_audio_tensor(int16_arr: np.ndarray):
305
+ arr = np.asarray(int16_arr, dtype=np.int16)
306
+ float32 = (arr.astype(np.float32) / 32767.0)
307
+ return torch.from_numpy(float32)
308
+
309
+
310
+ def save_audio_as_png_lossless(wav_tensor: torch.Tensor, png_path: str, width: int = 2048):
311
+ samples = audio_tensor_to_int16_array(wav_tensor)
312
+ N = samples.shape[0]
313
+ height = math.ceil(N / width)
314
+ total = width * height
315
+ pad = total - N
316
+ padded = np.pad(samples, (0, pad), mode='constant', constant_values=0).astype(np.int16)
317
+
318
+ arr = padded.reshape((height, width))
319
+ uint16_view = arr.view(np.uint16)
320
+
321
+ im = Image.fromarray(uint16_view, mode='I;16')
322
+ os.makedirs(os.path.dirname(png_path), exist_ok=True)
323
+ im.save(png_path, format='PNG')
324
+ print(f"Saved lossless audio PNG: {png_path} (samples={N}, width={width}, height={height})")
325
+ return png_path
326
+
327
+
328
+ def load_audio_from_png_lossless(png_path: str, original_length: int = None):
329
+ im = Image.open(png_path)
330
+ arr_uint16 = np.array(im, dtype=np.uint16)
331
+ int16_arr = arr_uint16.view(np.int16).reshape(-1)
332
+ if original_length is not None:
333
+ int16_arr = int16_arr[:original_length]
334
+ float_tensor = int16_array_to_audio_tensor(int16_arr)
335
+ return float_tensor # 1D torch tensor
336
+
337
+
338
+ def write_wav_from_tensor(tensor: torch.Tensor, out_wav_path: str, sr: int):
339
+ x = tensor.detach().cpu().numpy()
340
+ int16 = (np.clip(x, -1.0, 1.0) * 32767.0).astype(np.int16)
341
+ os.makedirs(os.path.dirname(out_wav_path), exist_ok=True)
342
+ with wave.open(out_wav_path, 'wb') as wf:
343
+ wf.setnchannels(1)
344
+ wf.setsampwidth(2)
345
+ wf.setframerate(sr)
346
+ wf.writeframes(int16.tobytes())
347
+ print(f"WAV written (lossless restore): {out_wav_path} (samples={int16.size}, sr={sr})")
348
+ return out_wav_path
349
+
350
+
351
+ # ----------------- CLI for quick local test -----------------
352
+ if __name__ == '__main__':
353
+ import argparse
354
+
355
+ parser = argparse.ArgumentParser(description='Denoise WAV and pack into PNG (pipeline module)')
356
+ parser.add_argument('--input', '-i', required=True, help='Input WAV file path')
357
+ parser.add_argument('--output', '-o', required=False, help='Output denoised WAV path (default: input_den.wav)')
358
+ parser.add_argument('--checkpoint', '-c', required=False, help='Checkpoint path')
359
+ parser.add_argument('--png-width', type=int, default=2048)
360
+ parser.add_argument('--no-pack', dest='pack', action='store_false')
361
+ parser.set_defaults(pack=True)
362
+
363
+ args = parser.parse_args()
364
+
365
+ inp = args.input
366
+ out = args.output or os.path.splitext(inp)[0] + '_denoised.wav'
367
+ ckpt = args.checkpoint or DEFAULT_CHECKPOINT
368
+ cfg = InferConfig(ckpt_path=ckpt)
369
+
370
+ print('Running pipeline...')
371
+ try:
372
+ out_wav, packed_png, recon = denoise_chunked_final(inp, out, cfg, chunk_seconds=50.0, overlap=0.5,
373
+ use_spectral_gate=True, noise_frac=0.1, subtract_strength=1.0,
374
+ pack_png=args.pack, png_width=args.png_width)
375
+ print('Done.\n', out_wav, packed_png, recon)
376
+ except Exception as e:
377
+ print('Pipeline error:', e)
378
+ raise
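The PNG pack/restore pair in `pipeline.py` is bit-perfect because int16 samples are reinterpreted as uint16 pixels, not re-encoded. A minimal sketch of that round-trip without the torch dependency (the helper names and `roundtrip.png` path are illustrative, not part of the module):

```python
import math
import numpy as np
from PIL import Image

def pack_int16_to_png(samples: np.ndarray, path: str, width: int = 64) -> None:
    # Zero-pad to a full width*height rectangle, then view the bits as uint16.
    height = math.ceil(samples.size / width)
    padded = np.zeros(width * height, dtype=np.int16)
    padded[: samples.size] = samples
    Image.fromarray(padded.reshape(height, width).view(np.uint16), mode="I;16").save(path)

def restore_int16_from_png(path: str, length: int) -> np.ndarray:
    # Read the 16-bit pixels back and reinterpret them as signed samples.
    pixels = np.array(Image.open(path), dtype=np.uint16)
    return pixels.view(np.int16).reshape(-1)[:length]

rng = np.random.default_rng(0)
original = rng.integers(-32768, 32768, size=1000, dtype=np.int16)
pack_int16_to_png(original, "roundtrip.png")
restored = restore_int16_from_png("roundtrip.png", original.size)
assert np.array_equal(original, restored)  # bit-perfect
```

The same idea drives `save_audio_as_png_lossless`/`load_audio_from_png_lossless`; the only lossy step in the full pipeline is the initial float-to-int16 quantisation. A typical invocation would be `python pipeline.py -i noisy.wav -c checkpoints/seagan_final.pt` (paths illustrative).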
requirements.txt ADDED
@@ -0,0 +1,7 @@
fastapi
uvicorn[standard]
torch==2.1.0+cpu
torchaudio==2.1.0+cpu
numpy
pillow
requests
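For reference, the chunk/hop bookkeeping in `denoise_chunked_final` can be sketched in a few lines. The 16 kHz sample rate below is an illustrative assumption (the real value comes from `InferConfig`); the point is that every input sample is covered by at least one Hann-windowed chunk before the weight normalisation step:

```python
# Sketch of the overlap-add arithmetic used by denoise_chunked_final.
# sr = 16000 is an assumed example rate, not a value fixed by the module.
sr = 16000
chunk_seconds, overlap = 3.0, 0.5

chunk_samples = max(1, int(chunk_seconds * sr))     # samples per chunk
hop = max(1, int(chunk_samples * (1.0 - overlap)))  # stride between chunk starts

print(chunk_samples, hop)  # 48000 24000

# With 50% overlap, chunk starts advance by half a chunk, so the
# union of [start, start + chunk_samples) windows covers the input.
T = 100_000  # example input length in samples
covered = set()
for start in range(0, T, hop):
    covered.update(range(start, start + chunk_samples))
assert set(range(T)) <= covered
```

Because overlapping Hann windows are summed into `weight_buffer` and divided out at the end, the reconstruction stays amplitude-correct even where chunk counts vary near the edges.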