91prince committed (verified)
Commit 6ddde83 · 1 Parent(s): 2b307a7

Upload 7 files

Files changed (7)
  1. api.py +41 -0
  2. app.py +57 -0
  3. best_denoiser_model.pth +3 -0
  4. inference.py +146 -0
  5. model_def.py +79 -0
  6. requirement.txt +8 -0
  7. test_api_client.py +31 -0
api.py ADDED
@@ -0,0 +1,50 @@
+ # api.py
+ import io
+
+ import torchaudio
+ from fastapi import FastAPI, UploadFile, File, Response
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from inference import denoise_file_bytes  # this must exist in inference.py
+
+ app = FastAPI(title="Audio Denoiser API")
+
+ # Optional CORS (handy if you later call from a frontend)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # tighten later if needed
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.get("/")
+ async def root():
+     return {"message": "Audio Denoiser API is running"}
+
+ # IMPORTANT: this is POST, not GET
+ @app.post("/denoise")
+ async def denoise_endpoint(file: UploadFile = File(...)):
+     """
+     Upload a noisy audio file (WAV); get back denoised audio bytes.
+     """
+     # Read the uploaded file into bytes
+     contents = await file.read()
+
+     # Run inference: returns (waveform tensor of shape (1, T), sample_rate)
+     denoised, sr = denoise_file_bytes(contents)
+
+     # Encode the denoised waveform as WAV bytes
+     buf = io.BytesIO()
+     torchaudio.save(buf, denoised, sr, format="wav")
+     denoised_bytes = buf.getvalue()
+
+     # Return as an audio/wav HTTP response
+     return Response(
+         content=denoised_bytes,
+         media_type="audio/wav",
+         headers={
+             # Makes the browser / client treat it as a downloadable file
+             "Content-Disposition": f'attachment; filename="denoised_{file.filename}"'
+         },
+     )
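
To try the API locally, one can point uvicorn at the `app` object. The launcher below is a hypothetical sketch (not part of this commit, filename assumed); running `uvicorn api:app --port 8000` from a shell does the same thing.

# run_api.py -- hypothetical local launcher (not in this commit);
# assumes api.py is importable from the current directory
import uvicorn

if __name__ == "__main__":
    uvicorn.run("api:app", host="127.0.0.1", port=8000)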
app.py ADDED
@@ -0,0 +1,61 @@
+ # app.py
+ import gradio as gr
+ import torch
+
+ from inference import denoise_waveform_tensor, CONFIG
+
+ TITLE = "Advanced Audio Denoiser (Spectrogram U-Net)"
+ DESCRIPTION = """
+ Upload a noisy WAV/MP3 audio file and the model will try to remove background noise.
+ This Space uses a ResUNet-based spectrogram denoiser trained by Prince.
+ """
+
+ EXAMPLES = []  # Add paths to example audio files here if you upload some
+
+
+ def denoise_gradio(input_audio):
+     """
+     input_audio: (sample_rate, np.ndarray) from Gradio
+     """
+     if input_audio is None:
+         return None
+
+     # Gradio passes (sr, np.ndarray) with type="numpy"
+     sr, waveform_np = input_audio
+
+     # Convert to torch: Gradio gives (T,) for mono or (T, C) for multi-channel
+     waveform = torch.from_numpy(waveform_np).float()
+     if waveform.dim() == 1:
+         waveform = waveform.unsqueeze(0)     # (T,) -> (1, T)
+     else:
+         waveform = waveform.transpose(0, 1)  # (T, C) -> (C, T)
+
+     denoised = denoise_waveform_tensor(waveform, sr)  # (1, T), resampled internally
+     denoised_np = denoised.squeeze(0).numpy()
+
+     # Gradio expects (sr, np.ndarray of shape (T,))
+     return (CONFIG["sample_rate"], denoised_np)
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(f"# {TITLE}")
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             inp = gr.Audio(
+                 sources=["upload"],
+                 type="numpy",
+                 label="Upload noisy audio",
+             )
+             btn = gr.Button("Denoise")
+         with gr.Column():
+             out = gr.Audio(
+                 type="numpy",
+                 label="Denoised audio",
+             )
+
+     btn.click(denoise_gradio, inputs=inp, outputs=out)
+
+ if __name__ == "__main__":
+     demo.launch()
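
Since `denoise_gradio` is a plain function, it can be exercised without launching the UI. The snippet below is a hypothetical smoke test (not in this commit, filename assumed); it requires best_denoiser_model.pth to be present so the model can load, and feeds one second of a noisy 440 Hz tone in Gradio's (sample_rate, np.ndarray) format.

# smoke_test_app.py -- hypothetical smoke test (not in this commit)
import numpy as np
from app import denoise_gradio  # assumes best_denoiser_model.pth is present

sr = 16000
t = np.linspace(0.0, 1.0, sr, endpoint=False)
noisy = (0.5 * np.sin(2 * np.pi * 440.0 * t)
         + 0.05 * np.random.randn(sr)).astype(np.float32)

out_sr, denoised = denoise_gradio((sr, noisy))
print(out_sr, denoised.shape)  # expected: 16000 (16000,)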
best_denoiser_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a39212b3737ab29c4eb26aa8642d3a383919de24c945365d8dc17d16e51664b
+ size 8132784
inference.py ADDED
@@ -0,0 +1,146 @@
+ # inference.py
+ import io
+ import torch
+ import torchaudio
+
+ from model_def import AdvancedResUNet
+
+ CONFIG = {
+     "sample_rate": 16000,
+     "n_fft": 1024,
+     "hop_length": 256,
+     "n_mels": 80,
+     "device": "cuda" if torch.cuda.is_available() else "cpu",
+     "model_path": "best_denoiser_model.pth",  # put this file in the Space
+ }
+
+ _model = None
+ _mel_scale = None
+ _inverse_mel = None
+ _window = None
+
+
+ def _get_device():
+     return CONFIG["device"]
+
+
+ def load_model():
+     global _model, _mel_scale, _inverse_mel, _window
+
+     if _model is not None:
+         return _model
+
+     device = _get_device()
+
+     model = AdvancedResUNet().to(device)
+     state_dict = torch.load(CONFIG["model_path"], map_location=device)
+     model.load_state_dict(state_dict)
+     model.eval()
+
+     # Mel scale, inverse mel scale, and STFT window
+     _mel_scale = torchaudio.transforms.MelScale(
+         n_mels=CONFIG["n_mels"],
+         sample_rate=CONFIG["sample_rate"],
+         n_stft=CONFIG["n_fft"] // 2 + 1,
+     ).to(device)
+
+     _inverse_mel = torchaudio.transforms.InverseMelScale(
+         n_stft=CONFIG["n_fft"] // 2 + 1,
+         n_mels=CONFIG["n_mels"],
+         sample_rate=CONFIG["sample_rate"],
+     ).to(device)
+
+     _window = torch.hann_window(CONFIG["n_fft"]).to(device)
+
+     _model = model
+     print(f"[inference] Model loaded on {device}")
+     return _model
+
+
+ def _normalize_waveform(waveform: torch.Tensor) -> torch.Tensor:
+     max_val = waveform.abs().max()
+     if max_val > 0:
+         waveform = waveform / max_val
+     return waveform
+
+
+ def denoise_waveform_tensor(waveform: torch.Tensor, sr: int) -> torch.Tensor:
+     """
+     waveform: Tensor of shape (1, T) on CPU
+     returns: denoised waveform Tensor (1, T) on CPU
+     """
+     device = _get_device()
+     model = load_model()
+
+     # Ensure mono
+     if waveform.dim() == 2 and waveform.size(0) > 1:
+         waveform = waveform.mean(dim=0, keepdim=True)
+
+     # Resample if needed
+     if sr != CONFIG["sample_rate"]:
+         resampler = torchaudio.transforms.Resample(sr, CONFIG["sample_rate"])
+         waveform = resampler(waveform)
+
+     waveform = _normalize_waveform(waveform)
+     waveform = waveform.to(device)
+
+     global _mel_scale, _inverse_mel, _window
+
+     # --- STFT: get magnitude and phase ---
+     stft_complex = torch.stft(
+         waveform,
+         n_fft=CONFIG["n_fft"],
+         hop_length=CONFIG["hop_length"],
+         window=_window,
+         return_complex=True,
+     )  # (1, n_freq, n_frames)
+
+     noisy_phase = torch.angle(stft_complex)
+     noisy_mag = torch.abs(stft_complex)  # (1, n_freq, n_frames)
+
+     # MelScale expects (batch, n_freq, time), which is what we already have
+     noisy_mel = _mel_scale(noisy_mag)  # (1, n_mels, n_frames)
+     noisy_log_mel = torch.log1p(noisy_mel + 1e-6)
+
+     # Model expects (B, 1, n_mels, T)
+     noisy_log_mel = noisy_log_mel.unsqueeze(1)  # (1, 1, n_mels, n_frames)
+
+     with torch.no_grad():
+         denoised_log_mel = model(noisy_log_mel)  # (1, 1, n_mels, n_frames)
+         denoised_log_mel = denoised_log_mel.squeeze(1)  # (1, n_mels, n_frames)
+
+     denoised_mel = torch.expm1(denoised_log_mel)
+     denoised_mel = torch.clamp(denoised_mel, min=0.0)
+
+     # Back to linear spectrogram magnitude
+     pred_mag = _inverse_mel(denoised_mel)  # (1, n_freq, n_frames)
+
+     # Combine predicted magnitude with the original noisy phase
+     complex_pred = pred_mag * torch.exp(1j * noisy_phase)
+
+     rec_waveform = torch.istft(
+         complex_pred,
+         n_fft=CONFIG["n_fft"],
+         hop_length=CONFIG["hop_length"],
+         window=_window,
+         length=waveform.shape[-1],
+     )  # (1, T) or (T,)
+
+     if rec_waveform.dim() == 1:
+         rec_waveform = rec_waveform.unsqueeze(0)
+
+     rec_waveform = _normalize_waveform(rec_waveform.cpu())
+     return rec_waveform  # (1, T) CPU
+
+
+ def denoise_file_bytes(file_bytes: bytes):
+     """
+     For the API / Gradio: takes encoded audio bytes, returns (waveform, sample_rate)
+     """
+     buf = io.BytesIO(file_bytes)
+     waveform, sr = torchaudio.load(buf)  # (channels, T), CPU
+     if waveform.dim() == 1:
+         waveform = waveform.unsqueeze(0)
+
+     denoised = denoise_waveform_tensor(waveform, sr)  # (1, T) CPU
+     return denoised, CONFIG["sample_rate"]
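
A quick way to sanity-check the full bytes-in / tensor-out path is to synthesize a WAV in memory and round-trip it through denoise_file_bytes. A minimal sketch (not in this commit, filename assumed), requiring the checkpoint file to be present:

# smoke_test_inference.py -- hypothetical check (not in this commit)
import io
import torch
import torchaudio
from inference import denoise_file_bytes, CONFIG

# One second of white noise, encoded as WAV bytes in memory
noisy = 0.1 * torch.randn(1, CONFIG["sample_rate"])
buf = io.BytesIO()
torchaudio.save(buf, noisy, CONFIG["sample_rate"], format="wav")

denoised, sr = denoise_file_bytes(buf.getvalue())
print(sr, denoised.shape)  # expected: 16000 torch.Size([1, 16000])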
model_def.py ADDED
@@ -0,0 +1,79 @@
+ # model_def.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.bn1 = nn.BatchNorm2d(out_channels)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+         self.bn2 = nn.BatchNorm2d(out_channels)
+
+         self.shortcut = nn.Sequential()
+         if in_channels != out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1),
+                 nn.BatchNorm2d(out_channels)
+             )
+
+     def forward(self, x):
+         residual = self.shortcut(x)
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += residual
+         return F.relu(out)
+
+ class AdvancedResUNet(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Encoder
+         self.enc1 = ResidualBlock(1, 32)
+         self.pool1 = nn.MaxPool2d(2)
+         self.enc2 = ResidualBlock(32, 64)
+         self.pool2 = nn.MaxPool2d(2)
+         self.enc3 = ResidualBlock(64, 128)
+         self.pool3 = nn.MaxPool2d(2)
+         # Bottleneck
+         self.bottleneck = ResidualBlock(128, 256)
+         # Decoder
+         self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.dec3 = ResidualBlock(256, 128)
+         self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.dec2 = ResidualBlock(128, 64)
+         self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
+         self.dec1 = ResidualBlock(64, 32)
+
+         self.final_conv = nn.Conv2d(32, 1, kernel_size=1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         e1 = self.enc1(x)
+         p1 = self.pool1(e1)
+         e2 = self.enc2(p1)
+         p2 = self.pool2(e2)
+         e3 = self.enc3(p2)
+         p3 = self.pool3(e3)
+         b = self.bottleneck(p3)
+
+         d3 = self.up3(b)
+         if d3.shape != e3.shape:
+             d3 = F.interpolate(d3, size=e3.shape[2:])
+         d3 = torch.cat([d3, e3], dim=1)
+         d3 = self.dec3(d3)
+
+         d2 = self.up2(d3)
+         if d2.shape != e2.shape:
+             d2 = F.interpolate(d2, size=e2.shape[2:])
+         d2 = torch.cat([d2, e2], dim=1)
+         d2 = self.dec2(d2)
+
+         d1 = self.up1(d2)
+         if d1.shape != e1.shape:
+             d1 = F.interpolate(d1, size=e1.shape[2:])
+         d1 = torch.cat([d1, e1], dim=1)
+         d1 = self.dec1(d1)
+
+         mask = self.sigmoid(self.final_conv(d1))
+         return x * mask
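
Because the network returns x * mask rather than a raw estimate, output shape always matches input shape; the F.interpolate guards handle inputs whose mel or frame counts are not divisible by 8. A minimal shape check (hypothetical, not in this commit):

# shape_check.py -- hypothetical sanity check (not in this commit)
import torch
from model_def import AdvancedResUNet

model = AdvancedResUNet().eval()
x = torch.rand(2, 1, 80, 128)  # (B, 1, n_mels, n_frames), as used in inference.py
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([2, 1, 80, 128]): the sigmoid mask preserves shape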
requirement.txt ADDED
@@ -0,0 +1,8 @@
+ # Python dependencies
+
+ fastapi
+ uvicorn
+ torch
+ torchaudio
+ requests
+ gradio
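
Note: Hugging Face Spaces auto-installs dependencies from a file named requirements.txt (plural), so this file likely needs renaming for the Space runtime to pick it up.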
test_api_client.py ADDED
@@ -0,0 +1,31 @@
+ # test_api_client.py
+ import os
+ import requests
+
+ API_URL = "http://127.0.0.1:8000/denoise"
+
+ # Path to one noisy test file
+ INPUT_WAV = r"E:\Test Audio Data\Test Audio Data\example_noisy.wav"
+ OUTPUT_WAV = r"E:\Test Audio Data\Test Audio Data\example_noisy_denoised_from_api.wav"
+
+ def test_denoise():
+     if not os.path.exists(INPUT_WAV):
+         print("Input file not found:", INPUT_WAV)
+         return
+
+     with open(INPUT_WAV, "rb") as f:
+         files = {"file": ("example_noisy.wav", f, "audio/wav")}
+         resp = requests.post(API_URL, files=files)
+
+     print("Status code:", resp.status_code)
+
+     if resp.status_code == 200:
+         with open(OUTPUT_WAV, "wb") as out_f:
+             out_f.write(resp.content)
+         print("Saved denoised file to:", OUTPUT_WAV)
+     else:
+         print("Error response body:")
+         print(resp.text)
+
+ if __name__ == "__main__":
+     test_denoise()