Spaces:
Runtime error
Runtime error
Commit
·
bd0a813
1
Parent(s):
3f204d4
fixes
Browse files- .gitignore +1 -0
- EDA.ipynb +0 -0
- README.md +4 -4
- app.py +9 -7
- datasets.py +39 -0
- denoisers/SpectralGating.py +1 -1
- denoisers/__pycache__/SpectralGating.cpython-38.pyc +0 -0
- denoisers/demucs.py +33 -24
- evaluation.py +2 -2
- metrics.py +5 -4
- train.py +121 -0
.gitignore
CHANGED
|
@@ -2,3 +2,4 @@
|
|
| 2 |
.ipynb_checkpoints/**
|
| 3 |
nohup.out
|
| 4 |
__pycache__/**
|
|
|
|
|
|
| 2 |
.ipynb_checkpoints/**
|
| 3 |
nohup.out
|
| 4 |
__pycache__/**
|
| 5 |
+
cache_wav/
|
EDA.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
|
| 4 |
# Testing
|
| 5 |
-
| |
|
| 6 |
-
|
| 7 |
-
| ideal denoising |
|
| 8 |
-
| baseline |
|
| 9 |
|
| 10 |
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
# Testing
|
| 5 |
+
| | valentini_PESQ | valentini_STOI |
|
| 6 |
+
|:---------------:|:--------------:|:--------------:|
|
| 7 |
+
| ideal denoising | 1.9709 | 0.9211 |
|
| 8 |
+
| baseline | 1.7433 | 0.8844 |
|
| 9 |
|
| 10 |
|
app.py
CHANGED
|
@@ -9,35 +9,37 @@ import logging
|
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
|
|
|
|
| 12 |
from denoisers.SpectralGating import SpectralGating
|
| 13 |
|
| 14 |
model = SpectralGating()
|
| 15 |
|
| 16 |
|
| 17 |
def denoising_transform(audio):
|
| 18 |
-
src_path = "cache_wav/
|
| 19 |
-
tgt_path = "cache_wav/
|
| 20 |
-
# os.rename(audio.name, src_path)
|
| 21 |
(ffmpeg.input(audio)
|
| 22 |
.output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
|
| 23 |
.run()
|
| 24 |
)
|
| 25 |
-
|
| 26 |
-
model.predict(src_path, tgt_path)
|
| 27 |
return tgt_path
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
|
| 31 |
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
|
| 32 |
|
| 33 |
-
title = "
|
| 34 |
#"""
|
| 35 |
gr.Interface(
|
| 36 |
denoising_transform, inputs, outputs, title=title,
|
| 37 |
-
allow_flagging='never'
|
| 38 |
).launch(
|
| 39 |
server_name='localhost',
|
| 40 |
server_port=7871,
|
|
|
|
| 41 |
#ssl_keyfile='example.key',
|
| 42 |
#ssl_certfile="example.crt",
|
| 43 |
)
|
|
|
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
|
| 12 |
+
|
| 13 |
from denoisers.SpectralGating import SpectralGating
|
| 14 |
|
| 15 |
model = SpectralGating()
|
| 16 |
|
| 17 |
|
| 18 |
def denoising_transform(audio):
    """Denoise one recording and return the path of the denoised WAV file.

    Args:
        audio: Path of the input audio file (the Gradio input is declared
            with ``type='filepath'``).

    Returns:
        The path handed to ``model.predict`` as its output location.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import uuid

    src_path = "cache_wav/original/{}.wav".format(str(uuid.uuid4()))
    tgt_path = "cache_wav/denoised/{}.wav".format(str(uuid.uuid4()))

    # cache_wav/ is git-ignored, so a fresh checkout has neither directory;
    # create them up front instead of relying on the writers to do so.
    for wav_path in (src_path, tgt_path):
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)

    # Keep a mono, 16-bit PCM, 22050 Hz copy of the upload.
    (ffmpeg.input(audio)
     .output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
     .run()
     )

    # NOTE(review): the model reads the raw upload, not the converted copy
    # at src_path - confirm that this is intended.
    model.predict(audio, tgt_path)
    return tgt_path
|
| 29 |
|
| 30 |
|
| 31 |
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
|
| 32 |
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
|
| 33 |
|
| 34 |
+
title = "Denoising"
|
| 35 |
#"""
|
| 36 |
gr.Interface(
|
| 37 |
denoising_transform, inputs, outputs, title=title,
|
| 38 |
+
allow_flagging='never'
|
| 39 |
).launch(
|
| 40 |
server_name='localhost',
|
| 41 |
server_port=7871,
|
| 42 |
+
share=True
|
| 43 |
#ssl_keyfile='example.key',
|
| 44 |
#ssl_certfile="example.crt",
|
| 45 |
)
|
datasets.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from utils import load_wav
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Valentini(Dataset):
    """Paired noisy/clean recordings read from two sibling directories.

    Files are taken from ``clean_trainset_56spk_wav`` and
    ``noisy_trainset_56spk_wav`` under ``dataset_path``. The first 20% of the
    sorted file list forms the validation split, the remainder the training
    split.

    Args:
        dataset_path: Root directory that contains the two sub-directories.
        transform: Optional callable applied to both the noisy and the clean
            waveform of every item.
        valid: When true, expose the validation split instead of the
            training split.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, dataset_path='/media/public/datasets/denoising/DS_10283_2791/', transform=None,
                 valid=False):
        root = Path(dataset_path)
        # glob() gives no ordering guarantee, so sort both listings: index i
        # must address the same recording in both lists, and the split must
        # not depend on directory listing order.
        # (presumably a pair shares one file name - TODO confirm)
        clean_wavs = sorted((root / 'clean_trainset_56spk_wav').glob("*"))
        noisy_wavs = sorted((root / 'noisy_trainset_56spk_wav').glob("*"))

        # Compare before slicing: the validation slices are always equally
        # long, so a check after slicing would miss a mismatch there.
        if len(clean_wavs) != len(noisy_wavs):
            raise ValueError(
                "clean/noisy file count mismatch: {} vs {}".format(len(clean_wavs), len(noisy_wavs)))

        valid_threshold = int(len(clean_wavs) * 0.2)
        split = slice(None, valid_threshold) if valid else slice(valid_threshold, None)
        self.clean_wavs = clean_wavs[split]
        self.noisy_wavs = noisy_wavs[split]

        self.transform = transform

    def __len__(self):
        """Return the number of pairs in the selected split."""
        return len(self.clean_wavs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(noisy_wav, clean_wav)``."""
        noisy_wav = load_wav(self.noisy_wavs[idx])
        clean_wav = load_wav(self.clean_wavs[idx])

        if self.transform:
            # Replay the same CPU RNG stream for both calls so that a random
            # transform makes identical choices for the noisy and the clean
            # signal, without reseeding the global generator to one of a
            # handful of fixed seeds.
            rng_state = torch.get_rng_state()
            noisy_wav = self.transform(noisy_wav)
            torch.set_rng_state(rng_state)
            clean_wav = self.transform(clean_wav)
        return noisy_wav, clean_wav
|
denoisers/SpectralGating.py
CHANGED
|
@@ -16,7 +16,7 @@ class SpectralGating(torch.nn.Module):
|
|
| 16 |
data, rate = torchaudio.load(wav_path)
|
| 17 |
reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
|
| 18 |
torchaudio.save(out_path, reduced_noise, rate)
|
| 19 |
-
return
|
| 20 |
|
| 21 |
|
| 22 |
|
|
|
|
| 16 |
data, rate = torchaudio.load(wav_path)
|
| 17 |
reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
|
| 18 |
torchaudio.save(out_path, reduced_noise, rate)
|
| 19 |
+
return out_path
|
| 20 |
|
| 21 |
|
| 22 |
|
denoisers/__pycache__/SpectralGating.cpython-38.pyc
ADDED
|
Binary file (1.08 kB). View file
|
|
|
denoisers/demucs.py
CHANGED
|
@@ -1,36 +1,34 @@
|
|
| 1 |
import torch
|
| 2 |
-
|
| 3 |
|
| 4 |
class Encoder(torch.nn.Module):
|
| 5 |
-
def __init__(self, in_channels, out_channels
|
| 6 |
-
kernel_size_1=8, stride_1=4,
|
| 7 |
-
kernel_size_2=1, stride_2=1):
|
| 8 |
super(Encoder, self).__init__()
|
| 9 |
|
| 10 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
| 11 |
-
kernel_size=
|
| 12 |
self.relu1 = torch.nn.ReLU()
|
| 13 |
self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
|
| 14 |
-
kernel_size=
|
| 15 |
-
self.glu = torch.nn.GLU()
|
| 16 |
|
| 17 |
def forward(self, x):
|
| 18 |
x = self.relu1(self.conv1(x))
|
|
|
|
|
|
|
| 19 |
x = self.glu(self.conv2(x))
|
| 20 |
return x
|
| 21 |
|
| 22 |
|
| 23 |
class Decoder(torch.nn.Module):
|
| 24 |
-
def __init__(self, in_channels, out_channels
|
| 25 |
-
kernel_size_1=3, stride_1=1,
|
| 26 |
-
kernel_size_2=8, stride_2=4):
|
| 27 |
super(Decoder, self).__init__()
|
| 28 |
|
| 29 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
|
| 30 |
-
kernel_size=
|
| 31 |
-
self.glu = torch.nn.GLU()
|
| 32 |
self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
| 33 |
-
kernel_size=
|
| 34 |
self.relu = torch.nn.ReLU()
|
| 35 |
|
| 36 |
def forward(self, x):
|
|
@@ -40,28 +38,39 @@ class Decoder(torch.nn.Module):
|
|
| 40 |
|
| 41 |
|
| 42 |
class Demucs(torch.nn.Module):
|
| 43 |
-
def __init__(self):
|
| 44 |
super(Demucs, self).__init__()
|
| 45 |
|
| 46 |
-
self.encoder1 = Encoder(in_channels=1, out_channels=
|
| 47 |
-
self.encoder2 = Encoder(in_channels=
|
| 48 |
-
self.encoder3 = Encoder(in_channels=
|
| 49 |
|
| 50 |
-
self.lstm = torch.nn.LSTM(
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
self.decoder1 = Decoder(in_channels=
|
| 53 |
-
self.decoder2 = Decoder(in_channels=
|
| 54 |
-
self.decoder3 = Decoder(in_channels=
|
| 55 |
|
| 56 |
def forward(self, x):
|
| 57 |
out1 = self.encoder1(x)
|
| 58 |
out2 = self.encoder2(out1)
|
| 59 |
out3 = self.encoder3(out2)
|
| 60 |
|
| 61 |
-
x = self.lstm(out3)
|
| 62 |
-
|
| 63 |
x = self.decoder1(x + out3)
|
|
|
|
| 64 |
x = self.decoder2(x + out2)
|
| 65 |
-
x =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
|
|
|
| 67 |
return x
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from torch.nn.functional import pad
|
| 3 |
|
| 4 |
class Encoder(torch.nn.Module):
    """Strided Conv1d + ReLU followed by a 1x1 Conv1d + GLU.

    The second convolution doubles the channel count and the GLU, applied on
    the second-to-last axis, halves it again, so the block emits
    ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=8, stride=2)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(out_channels, 2 * out_channels, kernel_size=1, stride=1)
        self.glu = torch.nn.GLU(dim=-2)

    def forward(self, x):
        """Encode ``x``; the last axis of the result always has even length."""
        hidden = self.relu1(self.conv1(x))
        # Right-pad with a single zero when the length is odd (pad of 0 otherwise).
        hidden = pad(hidden, (0, hidden.shape[-1] % 2))
        return self.glu(self.conv2(hidden))
|
| 21 |
|
| 22 |
|
| 23 |
class Decoder(torch.nn.Module):
|
| 24 |
+
def __init__(self, in_channels, out_channels):
|
|
|
|
|
|
|
| 25 |
super(Decoder, self).__init__()
|
| 26 |
|
| 27 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
|
| 28 |
+
kernel_size=1, stride=1)
|
| 29 |
+
self.glu = torch.nn.GLU(dim=-2)
|
| 30 |
self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
| 31 |
+
kernel_size=8, stride=2)
|
| 32 |
self.relu = torch.nn.ReLU()
|
| 33 |
|
| 34 |
def forward(self, x):
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
class Demucs(torch.nn.Module):
    """Three encoders, a two-layer LSTM and three decoders with skip connections.

    Args:
        H: Channel count produced by the first encoder; the second and third
            encoders produce ``2*H`` and ``4*H`` channels.
    """

    def __init__(self, H):
        super().__init__()

        self.encoder1 = Encoder(in_channels=1, out_channels=H)
        self.encoder2 = Encoder(in_channels=H, out_channels=2 * H)
        self.encoder3 = Encoder(in_channels=2 * H, out_channels=4 * H)

        self.lstm = torch.nn.LSTM(input_size=4 * H, hidden_size=4 * H,
                                  num_layers=2, batch_first=True)

        self.decoder1 = Decoder(in_channels=4 * H, out_channels=2 * H)
        self.decoder2 = Decoder(in_channels=2 * H, out_channels=H)
        self.decoder3 = Decoder(in_channels=H, out_channels=1)

    def forward(self, x):
        """Run the full encoder / LSTM / decoder stack on ``x``."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(skip1)
        skip3 = self.encoder3(skip2)

        # The LSTM is batch_first, so swap the last two axes on the way in
        # and swap them back on the way out.
        hidden, _ = self.lstm(skip3.permute(0, 2, 1))
        hidden = hidden.permute(0, 2, 1)

        hidden = self.decoder1(hidden + skip3)
        # Trim the decoder output to the length of the matching skip tensor.
        hidden = hidden[:, :, :skip2.shape[-1]]
        hidden = self.decoder2(hidden + skip2)

        # Drop the last frame of both tensors, then cut the longer one down
        # to the shorter one so they can be added.
        common = max(min(hidden.shape[-1], skip1.shape[-1]) - 1, 0)
        hidden = hidden[:, :, :common]
        skip1 = skip1[:, :, :common]

        return self.decoder3(hidden + skip1)
|
| 75 |
+
|
| 76 |
+
|
evaluation.py
CHANGED
|
@@ -28,10 +28,10 @@ def evaluate_on_dataset(model_name, dataset_path, dataset_type):
|
|
| 28 |
noisy_wav = load_wav(noisy_path)
|
| 29 |
|
| 30 |
if model_name is None:
|
| 31 |
-
scores = metrics.calculate(noisy_wav, clean_wav)
|
| 32 |
else:
|
| 33 |
denoised_wav = model(noisy_wav)
|
| 34 |
-
scores = metrics.calculate(
|
| 35 |
|
| 36 |
mean_scores['PESQ'] += scores['PESQ']
|
| 37 |
mean_scores['STOI'] += scores['STOI']
|
|
|
|
| 28 |
noisy_wav = load_wav(noisy_path)
|
| 29 |
|
| 30 |
if model_name is None:
|
| 31 |
+
scores = metrics.calculate(denoised=noisy_wav, clean=clean_wav)
|
| 32 |
else:
|
| 33 |
denoised_wav = model(noisy_wav)
|
| 34 |
+
scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav)
|
| 35 |
|
| 36 |
mean_scores['PESQ'] += scores['PESQ']
|
| 37 |
mean_scores['STOI'] += scores['STOI']
|
metrics.py
CHANGED
|
@@ -2,16 +2,17 @@ from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
|
| 2 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
| 3 |
import torch
|
| 4 |
import torchaudio
|
| 5 |
-
|
| 6 |
|
| 7 |
|
| 8 |
class Metrics:
|
| 9 |
def __init__(self, rate=16000):
|
| 10 |
self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
|
| 11 |
self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
|
|
|
|
| 12 |
|
| 13 |
-
def calculate(self,
|
| 14 |
-
return {'PESQ': self.nb_pesq(
|
| 15 |
-
'STOI': self.stoi(
|
| 16 |
|
| 17 |
|
|
|
|
| 2 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
| 3 |
import torch
|
| 4 |
import torchaudio
|
| 5 |
+
from torchmetrics import SignalNoiseRatio
|
| 6 |
|
| 7 |
|
| 8 |
class Metrics:
    """Speech-quality metrics configured for a single sample rate.

    Args:
        rate: Sample rate passed to the PESQ and STOI metric objects.
    """

    def __init__(self, rate=16000):
        # NOTE(review): the attribute is called nb_pesq but is built with
        # mode 'wb' - confirm which mode is intended.
        self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
        self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
        # Constructed here but not used by calculate().
        self.snr = SignalNoiseRatio()

    def calculate(self, denoised, clean):
        """Return ``{'PESQ': ..., 'STOI': ...}`` for ``denoised`` against ``clean``."""
        pesq_score = self.nb_pesq(denoised, clean)
        stoi_score = self.stoi(denoised, clean)
        return {'PESQ': pesq_score, 'STOI': stoi_score}
|
| 17 |
|
| 18 |
|
train.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 3 |
+
import torch
|
| 4 |
+
from torch.nn import Sequential
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from datasets import Valentini
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torchvision.transforms import RandomCrop
|
| 9 |
+
from utils import load_wav
|
| 10 |
+
from denoisers.demucs import Demucs
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
| 14 |
+
|
| 15 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 16 |
+
|
| 17 |
+
model = Demucs(H=64).to(device)
|
| 18 |
+
|
| 19 |
+
DATASET_PATH = Path('/media/public/datasets/denoising/DS_10283_2791/')
|
| 20 |
+
VALID_WAVS = {'hard': 'p257_171.wav',
|
| 21 |
+
'medium': 'p232_071.wav',
|
| 22 |
+
'easy': 'p232_284.wav'}
|
| 23 |
+
MAX_SECONDS = 3.2
|
| 24 |
+
SAMPLE_RATE = 16000
|
| 25 |
+
|
| 26 |
+
transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
|
| 27 |
+
|
| 28 |
+
training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
|
| 29 |
+
validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
|
| 30 |
+
|
| 31 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
|
| 32 |
+
loss_fn = torch.nn.MSELoss()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def train_one_epoch(epoch_index, tb_writer, log_interval=1000):
    """Run one optimisation pass over ``training_loader``.

    Args:
        epoch_index: Zero-based epoch index, used to compute the global step
            of the logged scalars.
        tb_writer: Writer that receives the ``Loss/train`` scalar.
        log_interval: Number of batches averaged into one reported value.

    Returns:
        The most recently reported mean loss per batch (0.0 when the loader
        yields no batches).
    """
    running_loss = 0.
    last_loss = 0.
    window_batches = 0
    total_batches = len(training_loader)

    for i, (inputs, labels) in enumerate(training_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window_batches += 1

        # Report after every full window and once more for the trailing
        # partial window, so that a short epoch still yields a real value.
        if window_batches == log_interval or i + 1 == total_batches:
            # Divide by the number of batches actually accumulated.
            last_loss = running_loss / window_batches
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * total_batches + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            window_batches = 0

    return last_loss
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def train():
    """Train ``model`` for a fixed number of epochs.

    Every epoch runs one training pass, computes the mean validation loss,
    logs both losses and the model output for the reference clips in
    ``VALID_WAVS`` to TensorBoard, and saves the model state whenever the
    validation loss improves.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))

    EPOCHS = 5
    best_vloss = float('inf')

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs('checkpoints', exist_ok=True)

    # Log the untouched noisy reference clips once.
    for tag, wav_path in VALID_WAVS.items():
        wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
        writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
    writer.flush()

    for epoch_number in range(EPOCHS):
        print('EPOCH {}:'.format(epoch_number + 1))

        model.train(True)
        avg_loss = train_one_epoch(epoch_number, writer)
        model.train(False)

        # Validation and the reference predictions need no gradients.
        with torch.no_grad():
            running_vloss = 0.0
            num_vbatches = 0
            for vinputs, vlabels in validation_loader:
                vinputs, vlabels = vinputs.to(device), vlabels.to(device)
                voutputs = model(vinputs)
                # .item() keeps the running sum a plain Python float.
                running_vloss += loss_fn(voutputs, vlabels).item()
                num_vbatches += 1

            # NaN for an empty loader: it never compares as an improvement.
            avg_vloss = running_vloss / num_vbatches if num_vbatches else float('nan')
            print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss},
                               epoch_number + 1)

            for tag, wav_path in VALID_WAVS.items():
                wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
                wav = torch.reshape(wav, (1, 1, -1)).to(device)
                prediction = model(wav)
                writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch_number}",
                                 snd_tensor=prediction,
                                 sample_rate=SAMPLE_RATE)
        writer.flush()

        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)

    writer.close()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == '__main__':
|
| 121 |
+
train()
|