Spaces:
Sleeping
Sleeping
initial commit
Browse files- .gitattributes +5 -0
- data/models/UVR-MDX-NET-Inst_HQ_3.onnx +3 -0
- data/samples/result.mp4 +3 -0
- data/samples/temp.mp3 +3 -0
- data/samples/temp.mp4 +3 -0
- data/samples/temp_no_vocals.wav +3 -0
- data/samples/temp_vocals.wav +3 -0
- demo.py +16 -0
- model.py +123 -0
- packages.txt +2 -0
- requirements.txt +11 -0
- source_separation.py +291 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/samples/result.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/samples/temp_no_vocals.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/samples/temp_vocals.wav filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
data/samples/temp.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
data/samples/temp.mp4 filter=lfs diff=lfs merge=lfs -text
|
data/models/UVR-MDX-NET-Inst_HQ_3.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:317554b07fe1ea5279a77f2b1520a41ea4b93432560c4ffd08792c30fddf9adc
|
| 3 |
+
size 66759214
|
data/samples/result.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f3a5c839f552d27b110e7db77ac74cb41a5c51c6c8376a75814aa4fc5a0c5921
|
| 3 |
+
size 16601916
|
data/samples/temp.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1322f661bf6c9b22a6e30283933f223358ad68fab06d73017cb80363e6e3ff50
|
| 3 |
+
size 4749941
|
data/samples/temp.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:302dd0780f1420599fa5bc179eb766981aac39883b4b79f8f0273f94d11d2542
|
| 3 |
+
size 14761845
|
data/samples/temp_no_vocals.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96ea44ba19641369a63e5ab8ec403e204b88e7aab35b7670f6af2b6811d912de
|
| 3 |
+
size 26179568
|
data/samples/temp_vocals.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:087b4afcc655ab2b0c0e25e196ee559bb661c996438ff897e5ef671cd51f4564
|
| 3 |
+
size 26179568
|
demo.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr

# NOTE: model.py lives at the repository root (see this commit's file list),
# so it must be imported as a top-level module, not via a `youtube_karaoke`
# package that does not exist in this repo.
from model import get_karaoke

# Minimal Gradio UI: paste a YouTube URL, get back the karaoke (vocal-removed)
# video produced by the separation pipeline.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(), gr.Row():
            url = gr.Textbox(placeholder="Youtube video URL", label="URL")

        with gr.Column():
            outputs = gr.PlayableVideo()

    transcribe_btn = gr.Button("YouTube Karaoke")
    transcribe_btn.click(get_karaoke, inputs=url, outputs=outputs)

demo.launch(debug=True)
|
model.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import soundfile as sf

# import torch
from moviepy import AudioFileClip, VideoFileClip
from pydub import AudioSegment
from pytubefix import YouTube
from pytubefix.cli import on_progress

# from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# NOTE: source_separation.py is committed at the repository root, so import it
# as a top-level module (there is no `youtube_karaoke` package in this repo).
from source_separation import Predictor
| 14 |
+
def download_from_youtube(url, folder_path):
    """Download the highest-resolution progressive stream of a YouTube video.

    Args:
        url (str): YouTube video URL.
        folder_path (str): Destination folder; the video is saved there
            as ``temp.mp4``.
    """
    video = YouTube(url, on_progress_callback=on_progress)
    print(video.title)

    stream = video.streams.get_highest_resolution()
    stream.download(output_path=folder_path, filename="temp.mp4")
| 22 |
+
def separate_video_and_audio(video_path, audio_path):
    """Extract the audio track of a video file and write it to a separate file.

    Args:
        video_path (str): Path to the input video file.
        audio_path (str): Path where the extracted audio is written.
    """
    # Load the video clip
    video_clip = VideoFileClip(video_path)
    try:
        # Extract the audio from the video clip
        audio_clip = video_clip.audio

        # Write the audio to a separate file
        audio_clip.write_audiofile(audio_path)
    finally:
        # Close explicitly: MoviePy clips hold ffmpeg reader subprocesses and
        # file handles that would otherwise leak on every call.
        video_clip.close()
+
|
| 33 |
+
def load_audio(audio_path, sample_rate=44_100):
    """Load an audio file and normalize it for source separation.

    The audio is resampled to ``sample_rate``, converted to 16-bit mono,
    gain-adjusted towards -20 dBFS (gain clamped to [-3, 3] dB), and finally
    peak-normalized into the [-1, 1] float range.

    Args:
        audio_path (str): Path to the input audio file.
        sample_rate (int, optional): Target sampling rate. Defaults to 44100.

    Returns:
        tuple: ``(waveform, sample_rate)`` where waveform is a float32
        ``np.ndarray``.
    """
    audio = AudioSegment.from_file(audio_path)

    print("Entering the preprocessing of audio")

    # Convert the audio file to WAV format
    audio = audio.set_frame_rate(sample_rate)
    audio = audio.set_sample_width(2)  # Set bit depth to 16bit
    audio = audio.set_channels(1)  # Set to mono

    print("Audio file converted to WAV format")

    # Calculate the gain to be applied
    target_dBFS = -20
    gain = target_dBFS - audio.dBFS
    print(f"Calculating the gain needed for the audio: {gain} dB")

    # Normalize volume and limit gain range to between -3 and 3
    normalized_audio = audio.apply_gain(min(max(gain, -3), 3))

    waveform = np.array(normalized_audio.get_array_of_samples(), dtype=np.float32)
    max_amplitude = np.max(np.abs(waveform))
    # Guard against division by zero for an all-silent input.
    if max_amplitude > 0:
        waveform /= max_amplitude  # Normalize to [-1, 1]

    print(f"waveform shape: {waveform.shape}")
    print("waveform in np ndarray, dtype=" + str(waveform.dtype))

    return waveform, sample_rate
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Configuration for the UVR-MDX-NET instrumental separation model.
args = {
    "model_path": "data/models/UVR-MDX-NET-Inst_HQ_3.onnx",
    "denoise": True,  # run the model on both signal polarities and average (see demix_base)
    "margin": 44100,  # overlap margin between chunks, in samples (1 s at 44.1 kHz)
    "chunks": 15,  # chunk length in seconds of 44.1 kHz audio; 0 disables chunking
    "n_fft": 6144,
    "dim_t": 8,  # log2 of the number of STFT time frames
    "dim_f": 3072,  # number of frequency bins the model operates on
}

# Instantiated once at import time so the ONNX session is loaded a single time
# and reused across requests.
separate_predictor = Predictor(args=args, device="cpu")
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def source_separation(waveform):
    """
    Separate the audio into vocals and non-vocals using the module-level predictor.

    Args:
        waveform (np.ndarray): Mono audio waveform to separate.

    Returns
    -------
    tuple: ``(vocals, no_vocals)`` mono waveforms (one channel each).
    """
    vocals, no_vocals = separate_predictor.predict(waveform)

    vocals = vocals[:, 0]  # vocals is stereo, only use one channel
    no_vocals = no_vocals[:, 0]  # no_vocals is stereo, only use one channel

    return vocals, no_vocals
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def export_to_wav(vocals, no_vocals, sample_rate, folder_path):
    """Export segmented audio to WAV files."""
    for stem, filename in ((vocals, "temp_vocals.wav"), (no_vocals, "temp_no_vocals.wav")):
        sf.write(folder_path + filename, stem, sample_rate)
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def combine_video_and_audio(video_path, no_vocals_path, output_path):
    """Mux the instrumental audio track back onto the original video.

    Args:
        video_path (str): Path to the original video (its own audio is ignored).
        no_vocals_path (str): Path to the instrumental-only audio file.
        output_path (str): Path where the combined video is written.
    """
    my_clip = VideoFileClip(video_path, audio=False)
    audio_background = AudioFileClip(no_vocals_path)
    try:
        my_clip.audio = audio_background
        my_clip.write_videofile(output_path)
    finally:
        # Close explicitly: MoviePy clips hold ffmpeg subprocesses and file
        # handles that would otherwise leak on every call.
        audio_background.close()
        my_clip.close()
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# https://www.youtube.com/watch?v=1jZEyU_eO1s
|
| 110 |
+
def get_karaoke(url):
    """Run the full karaoke pipeline for a YouTube URL.

    Downloads the video, extracts and preprocesses its audio, removes the
    vocals with the separation model, and muxes the instrumental track back
    onto the original video.

    Args:
        url (str): YouTube video URL.

    Returns:
        str: Path to the resulting karaoke video file.
    """
    # All intermediate files share fixed names, so concurrent calls would
    # overwrite each other's temp files.
    folder_path = "data/samples/"
    video_path = folder_path + "temp.mp4"
    audio_path = folder_path + "temp.mp3"
    no_vocals_path = folder_path + "temp_no_vocals.wav"
    output_path = folder_path + "result.mp4"

    download_from_youtube(url, folder_path)
    separate_video_and_audio(video_path, audio_path)
    waveform, sample_rate = load_audio(audio_path)
    vocals, no_vocals = source_separation(waveform)
    export_to_wav(vocals, no_vocals, sample_rate, folder_path)
    combine_video_and_audio(video_path, no_vocals_path, output_path)
    return output_path
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
libsndfile1
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
onnxruntime==1.20.1
|
| 2 |
+
torch==2.5.1
|
| 3 |
+
tqdm==4.67.1
|
| 4 |
+
llvmlite==0.43.0
|
| 5 |
+
librosa==0.10.2.post1
|
| 6 |
+
pydub==0.25.1
|
| 7 |
+
transformers==4.47.0
|
| 8 |
+
pytubefix==8.8.1
|
| 9 |
+
accelerate==1.2.0
|
| 10 |
+
moviepy==2.1.1
|
| 11 |
+
gradio==5.8.0
|
source_separation.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 seanghay
|
| 2 |
+
#
|
| 3 |
+
# This code is from an unlicensed repository.
|
| 4 |
+
#
|
| 5 |
+
# Note: This code has been modified to fit the context of this repository.
|
| 6 |
+
# This code is included in an MIT-licensed repository.
|
| 7 |
+
# The repository's MIT license does not apply to this code.
|
| 8 |
+
|
| 9 |
+
# This code is modified from https://github.com/seanghay/uvr-mdx-infer/blob/main/separate.py
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import onnxruntime as ort
|
| 13 |
+
import torch
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ConvTDFNet:
    """
    ConvTDFNet - Convolutional Temporal Frequency Domain Network.

    Holds the STFT/ISTFT parameters matching the ONNX separation model's
    expected spectrogram layout; the network weights themselves live in the
    ONNX session, not in this class.
    """

    def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
        """
        Initialize ConvTDFNet.

        Args:
            target_name (str): The target name for separation.
            L (int): Number of layers.
            dim_f (int): Dimension in the frequency domain.
            dim_t (int): Dimension in the time domain (log2).
            n_fft (int): FFT size.
            hop (int, optional): Hop size. Defaults to 1024.

        Returns
        -------
        None
        """
        # NOTE(review): the class has no base class, so this super().__init__()
        # is a no-op — likely adapted from an nn.Module implementation.
        super(ConvTDFNet, self).__init__()
        self.dim_c = 4  # 2 audio channels x (real, imag) — see the reshape in stft()
        self.dim_f = dim_f
        self.dim_t = 2**dim_t  # stored as the actual number of time frames
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        # Number of waveform samples that yields exactly dim_t STFT frames (center=True).
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.target_name = target_name

        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c

        # Zeros used in istft() to restore the frequency bins cropped by stft().
        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t])
        self.n = L // 2

    def stft(self, x):
        """
        Perform Short-Time Fourier Transform (STFT).

        Args:
            x (torch.Tensor): Input waveform; reshaped to (-1, chunk_size).

        Returns
        -------
        torch.Tensor: STFT of the input waveform, cropped to the first dim_f bins,
        shape (batch, dim_c, dim_f, dim_t).
        """
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        # Fold stereo pairs and (real, imag) into a single channel axis of size dim_c.
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, self.dim_c, self.n_bins, self.dim_t]
        )
        # Crop to the dim_f frequency bins the model operates on.
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        """
        Perform Inverse Short-Time Fourier Transform (ISTFT).

        Args:
            x (torch.Tensor): Input STFT.
            freq_pad (torch.Tensor, optional): Frequency padding. Defaults to None.

        Returns
        -------
        torch.Tensor: Inverse STFT of the input, shape (batch, c, chunk_size).
        """
        # Restore the frequency bins cropped in stft() with zeros.
        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == "*" else 2
        # Unfold the channel axis back into stereo pairs of (real, imag) planes.
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, c, self.chunk_size])
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Predictor:
    """
    Predictor class for source separation using ConvTDFNet and ONNX Runtime.
    """

    def __init__(self, args, device):
        """
        Initialize the Predictor.

        Args:
            args (dict): Configuration arguments; expects keys "model_path",
                "denoise", "margin", "chunks", "n_fft", "dim_t", "dim_f".
            device (str): Device to run the model ('cuda' or 'cpu').

        Returns
        -------
        None

        Raises
        ------
        ValueError: If the provided device is not 'cuda' or 'cpu'.
        """
        self.args = args
        # STFT/ISTFT helper describing the spectrogram layout the ONNX model expects.
        self.model_ = ConvTDFNet(
            target_name="vocals",
            L=11,
            dim_f=args["dim_f"],
            dim_t=args["dim_t"],
            n_fft=args["n_fft"],
        )

        if device == "cuda":
            self.model = ort.InferenceSession(
                args["model_path"], providers=["CUDAExecutionProvider"]
            )
        elif device == "cpu":
            self.model = ort.InferenceSession(
                args["model_path"], providers=["CPUExecutionProvider"]
            )
        else:
            raise ValueError("Device must be either 'cuda' or 'cpu'")

    def demix(self, mix):
        """
        Separate the sources from the input mix.

        Args:
            mix (np.ndarray): Input mixture signal, shape (2, samples).

        Returns
        -------
        np.ndarray: Separated sources.

        Raises
        ------
        AssertionError: If margin is zero.
        """
        samples = mix.shape[-1]
        margin = self.args["margin"]
        # "chunks" is expressed in seconds of 44.1 kHz audio.
        chunk_size = self.args["chunks"] * 44100

        assert margin != 0, "Margin cannot be zero!"

        margin = min(margin, chunk_size)

        segmented_mix = {}

        # Fall back to a single chunk when chunking is disabled or the input is short.
        if self.args["chunks"] == 0 or samples < chunk_size:
            chunk_size = samples

        counter = -1
        for skip in range(0, samples, chunk_size):
            counter += 1
            # The first chunk has no left margin; later chunks overlap by `margin`.
            s_margin = 0 if counter == 0 else margin
            end = min(skip + chunk_size + margin, samples)
            start = skip - s_margin
            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break

        sources = self.demix_base(segmented_mix, margin_size=margin)
        return sources

    def demix_base(self, mixes, margin_size):
        """
        Base function for source separation.

        Args:
            mixes (dict): Dictionary of segmented mixtures, keyed by chunk offset.
            margin_size (int): Size of the margin.

        Returns
        -------
        np.ndarray: Separated sources.
        """
        chunked_sources = []
        progress_bar = tqdm(total=len(mixes))
        progress_bar.set_description("Source separation")

        for mix in mixes:
            cmix = mixes[mix]
            sources = []
            n_sample = cmix.shape[1]
            model = self.model_
            trim = model.n_fft // 2
            gen_size = model.chunk_size - 2 * trim
            pad = gen_size - n_sample % gen_size
            # Zero-pad head/tail so the signal splits evenly into overlapping model windows.
            mix_p = np.concatenate(
                (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
            )
            mix_waves = []
            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + model.chunk_size])
                mix_waves.append(waves)
                i += gen_size

            mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32)

            with torch.no_grad():
                _ort = self.model
                spek = model.stft(mix_waves)
                if self.args["denoise"]:
                    # Run the model on the negated and the original spectrum and
                    # average the two predictions.
                    spec_pred = (
                        -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
                        + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
                    )
                    tar_waves = model.istft(torch.tensor(spec_pred))
                else:
                    tar_waves = model.istft(
                        torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
                    )
                # Drop the STFT trim regions, stitch windows back into one signal,
                # and remove the tail zero-padding added above.
                tar_signal = (
                    tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]
                )

                # Strip the overlap margins: the first chunk keeps its head and the
                # last chunk keeps its tail.
                start = 0 if mix == 0 else margin_size
                end = None if mix == list(mixes.keys())[::-1][0] else -margin_size

                if margin_size == 0:
                    end = None

                sources.append(tar_signal[:, start:end])

                progress_bar.update(1)

            chunked_sources.append(sources)
        _sources = np.concatenate(chunked_sources, axis=-1)

        progress_bar.close()
        return _sources

    def predict(self, mix):
        """
        Predict the separated sources from the input mix.

        Args:
            mix (np.ndarray): Input mixture signal; 1-D mono or (2, samples).

        Returns
        -------
        tuple: Tuple containing the mixture minus the separated sources and the
        separated sources, each shaped (samples, 2).
        """
        # Duplicate a mono signal into both channels: the model expects stereo.
        if mix.ndim == 1:
            mix = np.asfortranarray([mix, mix])

        # Pad the tail so the length is an exact multiple of the chunk size;
        # `tail` remembers the original remainder so the padding can be removed below.
        tail = mix.shape[1] % (self.args["chunks"] * 44100)
        if mix.shape[1] % (self.args["chunks"] * 44100) != 0:
            mix = np.pad(
                mix,
                (
                    (0, 0),
                    (
                        0,
                        self.args["chunks"] * 44100 - mix.shape[1] % (self.args["chunks"] * 44100),
                    ),
                ),
            )

        mix = mix.T
        sources = self.demix(mix.T)
        opt = sources[0].T

        # Trim the tail padding added above before returning.
        if tail != 0:
            return (
                (mix - opt)[: -(self.args["chunks"] * 44100 - tail), :],
                opt[: -(self.args["chunks"] * 44100 - tail), :],
            )
        return ((mix - opt), opt)
|