Update app.py
Browse files
app.py
CHANGED
|
@@ -4,38 +4,50 @@ from sgmse.model import ScoreModel
|
|
| 4 |
import gradio as gr
|
| 5 |
from sgmse.util.other import pad_spec
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# Load the pre-trained model
|
| 8 |
-
model = ScoreModel.load_from_checkpoint("
|
| 9 |
|
| 10 |
def enhance_speech(audio_file):
|
| 11 |
# Load and process the audio file
|
| 12 |
y, sr = torchaudio.load(audio_file)
|
| 13 |
-
|
| 14 |
-
T_orig = y.size(1)
|
| 15 |
|
| 16 |
# Normalize
|
| 17 |
norm_factor = y.abs().max()
|
| 18 |
y = y / norm_factor
|
| 19 |
-
|
| 20 |
# Prepare DNN input
|
| 21 |
-
Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args
|
| 22 |
-
Y = pad_spec(Y, mode=pad_mode
|
| 23 |
-
|
| 24 |
# Reverse sampling
|
| 25 |
sampler = model.get_pc_sampler(
|
| 26 |
-
'reverse_diffusion', args
|
| 27 |
-
corrector_steps=args
|
|
|
|
| 28 |
sample, _ = sampler()
|
| 29 |
-
|
| 30 |
# Backward transform in time domain
|
| 31 |
x_hat = model.to_audio(sample.squeeze(), T_orig)
|
| 32 |
|
| 33 |
# Renormalize
|
| 34 |
x_hat = x_hat * norm_factor
|
| 35 |
-
|
| 36 |
# Save the enhanced audio
|
| 37 |
output_file = 'enhanced_output.wav'
|
| 38 |
-
torchaudio.save(output_file, x_hat.cpu()
|
| 39 |
|
| 40 |
return output_file
|
| 41 |
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
from sgmse.util.other import pad_spec
|
| 6 |
|
| 7 |
+
# Inference configuration mirroring the argparse defaults in enhancement.py.
# NOTE(review): values reconstructed from the upstream script — confirm they
# match the training setup of the referenced checkpoint before changing.
args = {
    "test_dir": "./test_data",          # example directory, adjust as needed
    "enhanced_dir": "./enhanced_data",  # example directory, adjust as needed
    "ckpt": "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
    "corrector": "ald",       # corrector passed to the PC sampler
    "corrector_steps": 1,     # corrector iterations per reverse step
    "snr": 0.5,               # corrector step-size / SNR parameter
    "N": 30,                  # number of reverse diffusion steps
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Load the pre-trained model once at import time (the checkpoint URL is
# fetched on first load). Move it to the configured device — enhance_speech()
# moves its inputs to args["device"], so leaving the model on CPU would raise
# a device-mismatch error on CUDA machines — and switch to eval mode so
# dropout/normalization layers behave deterministically during inference.
model = ScoreModel.load_from_checkpoint(args["ckpt"])
model.to(args["device"])
model.eval()
|
| 21 |
|
| 22 |
def enhance_speech(audio_file):
    """Enhance a noisy speech recording with the pre-trained SGMSE model.

    Parameters
    ----------
    audio_file : str
        Path to an input audio file readable by ``torchaudio.load``.

    Returns
    -------
    str
        Path to the enhanced audio, written to ``enhanced_output.wav``.
    """
    # Load and process the audio file.
    y, sr = torchaudio.load(audio_file)
    T_orig = y.size(1)  # remember original length to trim STFT padding later

    # Normalize to unit peak. Guard against an all-zero (silent) input,
    # which would otherwise produce NaNs from dividing by zero.
    norm_factor = y.abs().max()
    if norm_factor > 0:
        y = y / norm_factor

    # Pure inference: disable autograd to save memory and time.
    with torch.no_grad():
        # Prepare DNN input: STFT -> model's forward transform, add batch dim.
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args["device"]))), 0)
        # NOTE(review): upstream enhancement.py uses a configurable pad_mode;
        # confirm "constant" matches the mode the checkpoint was trained with.
        Y = pad_spec(Y, mode="constant")

        # Reverse sampling with the predictor-corrector sampler.
        sampler = model.get_pc_sampler(
            'reverse_diffusion', args["corrector"], Y.to(args["device"]),
            N=args["N"], corrector_steps=args["corrector_steps"], snr=args["snr"]
        )
        sample, _ = sampler()

        # Backward transform to the time domain, trimmed to the original length.
        x_hat = model.to_audio(sample.squeeze(), T_orig)

    # Renormalize back to the input's original scale.
    x_hat = x_hat * norm_factor

    # Save the enhanced audio. torchaudio.save expects a (channels, time)
    # tensor; to_audio on a squeezed mono sample can return 1-D, so re-add
    # the channel dimension defensively before writing.
    output_file = 'enhanced_output.wav'
    if x_hat.dim() == 1:
        x_hat = x_hat.unsqueeze(0)
    torchaudio.save(output_file, x_hat.cpu(), sr)

    return output_file
|
| 53 |
|