Spaces:
Running
Running
use pth rather than ckpt
Browse files- ans.py +10 -10
- app.py +17 -17
- inference.py +11 -5
- inference_full.py +3 -3
- start_script.sh +9 -6
ans.py
CHANGED
|
@@ -7,21 +7,21 @@ import soundfile as sf
|
|
| 7 |
from inference import load_config_and_state_dict, load_generator, process_audio
|
| 8 |
|
| 9 |
MSS_MODEL_PATHS = {
|
| 10 |
-
'vox': ['checkpoints/vox_mss.
|
| 11 |
-
'gtr': ['checkpoints/gtr_mss.
|
| 12 |
-
'key': ['checkpoints/key_mss.
|
| 13 |
-
'syn': ['checkpoints/syn_mss.
|
| 14 |
-
'bass': ['checkpoints/bass_mss.
|
| 15 |
-
'drums': ['checkpoints/drums_mss.
|
| 16 |
-
'perc': ['checkpoints/perc_mss.
|
| 17 |
-
'orch': ['checkpoints/orch_mss.
|
| 18 |
}
|
| 19 |
MSS_MODELS = {key: [] for key in MSS_MODEL_PATHS.keys()}
|
| 20 |
|
| 21 |
-
PRE_MODEL_PATHS = {'vox': ['checkpoints/denoise.
|
| 22 |
PRE_MODELS = {key: [] for key in PRE_MODEL_PATHS.keys()}
|
| 23 |
|
| 24 |
-
POST_MODEL_PATHS = {key: ['checkpoints/dereverb.
|
| 25 |
POST_MODELS = {key: [] for key in POST_MODEL_PATHS.keys()}
|
| 26 |
|
| 27 |
OUTPUT_PATHS = {
|
|
|
|
| 7 |
from inference import load_config_and_state_dict, load_generator, process_audio
|
| 8 |
|
| 9 |
MSS_MODEL_PATHS = {
|
| 10 |
+
'vox': ['checkpoints/vox_mss.pth'],
|
| 11 |
+
'gtr': ['checkpoints/gtr_mss.pth'],
|
| 12 |
+
'key': ['checkpoints/key_mss.pth'],
|
| 13 |
+
'syn': ['checkpoints/syn_mss.pth', 'checkpoints/syn_mss1.pth'],
|
| 14 |
+
'bass': ['checkpoints/bass_mss.pth'],
|
| 15 |
+
'drums': ['checkpoints/drums_mss.pth', 'checkpoints/drums_mss1.pth'],
|
| 16 |
+
'perc': ['checkpoints/perc_mss.pth', 'checkpoints/perc_mss1.pth'],
|
| 17 |
+
'orch': ['checkpoints/orch_mss.pth', 'checkpoints/orch_mss1.pth'],
|
| 18 |
}
|
| 19 |
MSS_MODELS = {key: [] for key in MSS_MODEL_PATHS.keys()}
|
| 20 |
|
| 21 |
+
PRE_MODEL_PATHS = {'vox': ['checkpoints/denoise.pth']}
|
| 22 |
PRE_MODELS = {key: [] for key in PRE_MODEL_PATHS.keys()}
|
| 23 |
|
| 24 |
+
POST_MODEL_PATHS = {key: ['checkpoints/dereverb.pth'] if key in ['vox'] else [] for key in MSS_MODEL_PATHS.keys()}
|
| 25 |
POST_MODELS = {key: [] for key in POST_MODEL_PATHS.keys()}
|
| 26 |
|
| 27 |
OUTPUT_PATHS = {
|
app.py
CHANGED
|
@@ -14,43 +14,43 @@ REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/xlance-msr-ckpt")
|
|
| 14 |
# Instrument to checkpoint mapping
|
| 15 |
INSTRUMENT_MAP = {
|
| 16 |
'vox': {
|
| 17 |
-
'pre': ['denoise.
|
| 18 |
-
'mss': ['vox_mss.
|
| 19 |
-
'post': ['dereverb.
|
| 20 |
},
|
| 21 |
'gtr': {
|
| 22 |
-
'pre': ['denoise.
|
| 23 |
-
'mss': ['gtr_mss.
|
| 24 |
'post': []
|
| 25 |
},
|
| 26 |
'key': {
|
| 27 |
-
'pre': ['denoise.
|
| 28 |
-
'mss': ['key_mss.
|
| 29 |
'post': []
|
| 30 |
},
|
| 31 |
'syn': {
|
| 32 |
-
'pre': ['denoise.
|
| 33 |
-
'mss': ['syn_mss.
|
| 34 |
'post': []
|
| 35 |
},
|
| 36 |
'bass': {
|
| 37 |
-
'pre': ['denoise.
|
| 38 |
-
'mss': ['bass_mss.
|
| 39 |
'post': []
|
| 40 |
},
|
| 41 |
'drums': {
|
| 42 |
-
'pre': ['denoise.
|
| 43 |
-
'mss': ['drums_mss.
|
| 44 |
'post': []
|
| 45 |
},
|
| 46 |
'perc': {
|
| 47 |
-
'pre': ['denoise.
|
| 48 |
-
'mss': ['perc_mss.
|
| 49 |
'post': []
|
| 50 |
},
|
| 51 |
'orch': {
|
| 52 |
-
'pre': ['denoise.
|
| 53 |
-
'mss': ['orch_mss.
|
| 54 |
'post': []
|
| 55 |
}
|
| 56 |
}
|
|
|
|
| 14 |
# Instrument to checkpoint mapping
|
| 15 |
INSTRUMENT_MAP = {
|
| 16 |
'vox': {
|
| 17 |
+
'pre': ['denoise.pth'],
|
| 18 |
+
'mss': ['vox_mss.pth'],
|
| 19 |
+
'post': ['dereverb.pth']
|
| 20 |
},
|
| 21 |
'gtr': {
|
| 22 |
+
'pre': ['denoise.pth'],
|
| 23 |
+
'mss': ['gtr_mss.pth'],
|
| 24 |
'post': []
|
| 25 |
},
|
| 26 |
'key': {
|
| 27 |
+
'pre': ['denoise.pth'],
|
| 28 |
+
'mss': ['key_mss.pth'],
|
| 29 |
'post': []
|
| 30 |
},
|
| 31 |
'syn': {
|
| 32 |
+
'pre': ['denoise.pth'],
|
| 33 |
+
'mss': ['syn_mss.pth', 'syn_mss1.pth'],
|
| 34 |
'post': []
|
| 35 |
},
|
| 36 |
'bass': {
|
| 37 |
+
'pre': ['denoise.pth'],
|
| 38 |
+
'mss': ['bass_mss.pth'],
|
| 39 |
'post': []
|
| 40 |
},
|
| 41 |
'drums': {
|
| 42 |
+
'pre': ['denoise.pth'],
|
| 43 |
+
'mss': ['drums_mss.pth', 'drums_mss1.pth'],
|
| 44 |
'post': []
|
| 45 |
},
|
| 46 |
'perc': {
|
| 47 |
+
'pre': ['denoise.pth'],
|
| 48 |
+
'mss': ['perc_mss.pth', 'perc_mss1.pth'],
|
| 49 |
'post': []
|
| 50 |
},
|
| 51 |
'orch': {
|
| 52 |
+
'pre': ['denoise.pth'],
|
| 53 |
+
'mss': ['orch_mss.pth', 'orch_mss1.pth'],
|
| 54 |
'post': []
|
| 55 |
}
|
| 56 |
}
|
inference.py
CHANGED
|
@@ -14,6 +14,7 @@ import soundfile as sf
|
|
| 14 |
import numpy as np
|
| 15 |
from tqdm import tqdm
|
| 16 |
import librosa
|
|
|
|
| 17 |
from models import MelRNN, MelRoFormer, UNet, UFormer
|
| 18 |
from models.bs_roformer import bs_roformer as BSRoformer
|
| 19 |
from models.bs_roformer import mel_band_roformer as MelBandRoformer
|
|
@@ -55,9 +56,14 @@ class RoformerSequential(nn.Sequential):
|
|
| 55 |
return self[-1](mixture, target) # also pass target if present
|
| 56 |
|
| 57 |
def load_config_and_state_dict(path: str, map_location: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 58 |
-
if path.endswith('.pth'):
|
| 59 |
-
raise ValueError("Use .ckpt files instead of .pth files")
|
| 60 |
print(f"Extracting state dict from {path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
full_checkpoint = torch.load(path, map_location=map_location, weights_only=False)
|
| 62 |
full_state_dict = full_checkpoint['state_dict']
|
| 63 |
generator_state_dict = OrderedDict()
|
|
@@ -119,9 +125,9 @@ def process_audio(config, audio: np.ndarray, generator: nn.Module, device: str =
|
|
| 119 |
|
| 120 |
def main():
|
| 121 |
parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
|
| 122 |
-
parser.add_argument("--checkpoint", '-c', type=str, required=True, help="Path to unwrapped generator weights (.ckpt)")
|
| 123 |
-
parser.add_argument("--checkpoint_pre", '-p', type=str, help="pre-processing model checkpoint (.ckpt)")
|
| 124 |
-
parser.add_argument("--checkpoint_post", '-P', type=str, help="post-processing model checkpoint (.ckpt)")
|
| 125 |
parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input .flac files")
|
| 126 |
parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio")
|
| 127 |
parser.add_argument("--instrument", type=str, help="Instrument to process (Vox/Gtr/Kbs/Synth/Bass/Rhy_DK/Rhy_PERC/Orch)")
|
|
|
|
| 14 |
import numpy as np
|
| 15 |
from tqdm import tqdm
|
| 16 |
import librosa
|
| 17 |
+
import yaml
|
| 18 |
from models import MelRNN, MelRoFormer, UNet, UFormer
|
| 19 |
from models.bs_roformer import bs_roformer as BSRoformer
|
| 20 |
from models.bs_roformer import mel_band_roformer as MelBandRoformer
|
|
|
|
| 56 |
return self[-1](mixture, target) # also pass target if present
|
| 57 |
|
| 58 |
def load_config_and_state_dict(path: str, map_location: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
|
|
|
|
|
| 59 |
print(f"Extracting state dict from {path}")
|
| 60 |
+
if path.endswith('.pth'):
|
| 61 |
+
model_name = Path(path).stem
|
| 62 |
+
config_path = f"./configs/{model_name}.yaml" # use config file with same name as model in ./configs
|
| 63 |
+
print(f"Loading config from {config_path}")
|
| 64 |
+
with open(config_path, 'r') as f:
|
| 65 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 66 |
+
return config, torch.load(path, map_location=map_location)
|
| 67 |
full_checkpoint = torch.load(path, map_location=map_location, weights_only=False)
|
| 68 |
full_state_dict = full_checkpoint['state_dict']
|
| 69 |
generator_state_dict = OrderedDict()
|
|
|
|
| 125 |
|
| 126 |
def main():
|
| 127 |
parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
|
| 128 |
+
parser.add_argument("--checkpoint", '-c', type=str, required=True, help="Path to unwrapped generator weights (.ckpt or .pth)")
|
| 129 |
+
parser.add_argument("--checkpoint_pre", '-p', type=str, help="pre-processing model checkpoint (.ckpt or .pth)")
|
| 130 |
+
parser.add_argument("--checkpoint_post", '-P', type=str, help="post-processing model checkpoint (.ckpt or .pth)")
|
| 131 |
parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input .flac files")
|
| 132 |
parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio")
|
| 133 |
parser.add_argument("--instrument", type=str, help="Instrument to process (Vox/Gtr/Kbs/Synth/Bass/Rhy_DK/Rhy_PERC/Orch)")
|
inference_full.py
CHANGED
|
@@ -171,9 +171,9 @@ def inference_main(args):
|
|
| 171 |
|
| 172 |
if __name__ == '__main__':
|
| 173 |
parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
|
| 174 |
-
parser.add_argument("--checkpoint", '-c', nargs='*', default=[], type=str, help="model checkpoint (.ckpt)")
|
| 175 |
-
parser.add_argument("--checkpoint_pre", '-p', nargs='*', default=[], type=str, help="pre-processing model checkpoint (.ckpt)")
|
| 176 |
-
parser.add_argument("--checkpoint_post", '-P', nargs='*', default=[], type=str, help="post-processing model checkpoint (.ckpt)")
|
| 177 |
parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input files, or a single audio file")
|
| 178 |
parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio, or a single audio file name")
|
| 179 |
parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
|
|
|
|
| 171 |
|
| 172 |
if __name__ == '__main__':
|
| 173 |
parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
|
| 174 |
+
parser.add_argument("--checkpoint", '-c', nargs='*', default=[], type=str, help="model checkpoint (.ckpt or .pth)")
|
| 175 |
+
parser.add_argument("--checkpoint_pre", '-p', nargs='*', default=[], type=str, help="pre-processing model checkpoint (.ckpt or .pth)")
|
| 176 |
+
parser.add_argument("--checkpoint_post", '-P', nargs='*', default=[], type=str, help="post-processing model checkpoint (.ckpt or .pth)")
|
| 177 |
parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input files, or a single audio file")
|
| 178 |
parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio, or a single audio file name")
|
| 179 |
parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
|
start_script.sh
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# inference_full (for inference on full audio)
|
| 2 |
python inference_full.py \
|
| 3 |
-
-p checkpoints/denoise.
|
| 4 |
-
-c checkpoints/vox_mss.
|
| 5 |
-
-P checkpoints/dereverb.
|
| 6 |
-i test/input \
|
| 7 |
-o test/output \
|
| 8 |
--device cuda \
|
|
@@ -10,8 +10,8 @@ python inference_full.py \
|
|
| 10 |
|
| 11 |
# another example (eight instruments are supported in total)
|
| 12 |
python inference_full.py \
|
| 13 |
-
-p checkpoints/denoise.
|
| 14 |
-
-c checkpoints/drums_mss.
|
| 15 |
-i test/input \
|
| 16 |
-o test/output \
|
| 17 |
--device cuda \
|
|
@@ -39,4 +39,7 @@ python inference.py \
|
|
| 39 |
-P logs/dereverb/vox/checkpoints/00000010.ckpt \
|
| 40 |
-i OrganizersMixture/Vocals \
|
| 41 |
-o Answer/Vocals/default \
|
| 42 |
-
--no-eval
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# inference_full (for inference on full audio)
|
| 2 |
python inference_full.py \
|
| 3 |
+
-p checkpoints/denoise.pth \
|
| 4 |
+
-c checkpoints/vox_mss.pth \
|
| 5 |
+
-P checkpoints/dereverb.pth \ # only vox need it
|
| 6 |
-i test/input \
|
| 7 |
-o test/output \
|
| 8 |
--device cuda \
|
|
|
|
| 10 |
|
| 11 |
# another example (eight instruments are supported in total)
|
| 12 |
python inference_full.py \
|
| 13 |
+
-p checkpoints/denoise.pth \
|
| 14 |
+
-c checkpoints/drums_mss.pth checkpoints/drums_mss1.pth \ # use multiple checkpoints for some instruments
|
| 15 |
-i test/input \
|
| 16 |
-o test/output \
|
| 17 |
--device cuda \
|
|
|
|
| 39 |
-P logs/dereverb/vox/checkpoints/00000010.ckpt \
|
| 40 |
-i OrganizersMixture/Vocals \
|
| 41 |
-o Answer/Vocals/default \
|
| 42 |
+
--no-eval
|
| 43 |
+
|
| 44 |
+
# unwrap
|
| 45 |
+
python unwrap.py -i checkpoints/bass_mss.ckpt -o checkpoints/bass_mss.pth
|