Jihuai committed on
Commit
fd56f2f
·
1 Parent(s): 2c60d44

use pth rather than ckpt

Browse files
Files changed (5) hide show
  1. ans.py +10 -10
  2. app.py +17 -17
  3. inference.py +11 -5
  4. inference_full.py +3 -3
  5. start_script.sh +9 -6
ans.py CHANGED
@@ -7,21 +7,21 @@ import soundfile as sf
7
  from inference import load_config_and_state_dict, load_generator, process_audio
8
 
9
  MSS_MODEL_PATHS = {
10
- 'vox': ['checkpoints/vox_mss.ckpt'],
11
- 'gtr': ['checkpoints/gtr_mss.ckpt'],
12
- 'key': ['checkpoints/key_mss.ckpt'],
13
- 'syn': ['checkpoints/syn_mss.ckpt', 'checkpoints/syn_mss1.ckpt'],
14
- 'bass': ['checkpoints/bass_mss.ckpt'],
15
- 'drums': ['checkpoints/drums_mss.ckpt', 'checkpoints/drums_mss1.ckpt'],
16
- 'perc': ['checkpoints/perc_mss.ckpt', 'checkpoints/perc_mss1.ckpt'],
17
- 'orch': ['checkpoints/orch_mss.ckpt', 'checkpoints/orch_mss1.ckpt'],
18
  }
19
  MSS_MODELS = {key: [] for key in MSS_MODEL_PATHS.keys()}
20
 
21
- PRE_MODEL_PATHS = {'vox': ['checkpoints/denoise.ckpt']}
22
  PRE_MODELS = {key: [] for key in PRE_MODEL_PATHS.keys()}
23
 
24
- POST_MODEL_PATHS = {key: ['checkpoints/dereverb.ckpt'] if key in ['vox'] else [] for key in MSS_MODEL_PATHS.keys()}
25
  POST_MODELS = {key: [] for key in POST_MODEL_PATHS.keys()}
26
 
27
  OUTPUT_PATHS = {
 
7
  from inference import load_config_and_state_dict, load_generator, process_audio
8
 
9
  MSS_MODEL_PATHS = {
10
+ 'vox': ['checkpoints/vox_mss.pth'],
11
+ 'gtr': ['checkpoints/gtr_mss.pth'],
12
+ 'key': ['checkpoints/key_mss.pth'],
13
+ 'syn': ['checkpoints/syn_mss.pth', 'checkpoints/syn_mss1.pth'],
14
+ 'bass': ['checkpoints/bass_mss.pth'],
15
+ 'drums': ['checkpoints/drums_mss.pth', 'checkpoints/drums_mss1.pth'],
16
+ 'perc': ['checkpoints/perc_mss.pth', 'checkpoints/perc_mss1.pth'],
17
+ 'orch': ['checkpoints/orch_mss.pth', 'checkpoints/orch_mss1.pth'],
18
  }
19
  MSS_MODELS = {key: [] for key in MSS_MODEL_PATHS.keys()}
20
 
21
+ PRE_MODEL_PATHS = {'vox': ['checkpoints/denoise.pth']}
22
  PRE_MODELS = {key: [] for key in PRE_MODEL_PATHS.keys()}
23
 
24
+ POST_MODEL_PATHS = {key: ['checkpoints/dereverb.pth'] if key in ['vox'] else [] for key in MSS_MODEL_PATHS.keys()}
25
  POST_MODELS = {key: [] for key in POST_MODEL_PATHS.keys()}
26
 
27
  OUTPUT_PATHS = {
app.py CHANGED
@@ -14,43 +14,43 @@ REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/xlance-msr-ckpt")
14
  # Instrument to checkpoint mapping
15
  INSTRUMENT_MAP = {
16
  'vox': {
17
- 'pre': ['denoise.ckpt'],
18
- 'mss': ['vox_mss.ckpt'],
19
- 'post': ['dereverb.ckpt']
20
  },
21
  'gtr': {
22
- 'pre': ['denoise.ckpt'],
23
- 'mss': ['gtr_mss.ckpt'],
24
  'post': []
25
  },
26
  'key': {
27
- 'pre': ['denoise.ckpt'],
28
- 'mss': ['key_mss.ckpt'],
29
  'post': []
30
  },
31
  'syn': {
32
- 'pre': ['denoise.ckpt'],
33
- 'mss': ['syn_mss.ckpt', 'syn_mss1.ckpt'],
34
  'post': []
35
  },
36
  'bass': {
37
- 'pre': ['denoise.ckpt'],
38
- 'mss': ['bass_mss.ckpt'],
39
  'post': []
40
  },
41
  'drums': {
42
- 'pre': ['denoise.ckpt'],
43
- 'mss': ['drums_mss.ckpt', 'drums_mss1.ckpt'],
44
  'post': []
45
  },
46
  'perc': {
47
- 'pre': ['denoise.ckpt'],
48
- 'mss': ['perc_mss.ckpt', 'perc_mss1.ckpt'],
49
  'post': []
50
  },
51
  'orch': {
52
- 'pre': ['denoise.ckpt'],
53
- 'mss': ['orch_mss.ckpt', 'orch_mss1.ckpt'],
54
  'post': []
55
  }
56
  }
 
14
  # Instrument to checkpoint mapping
15
  INSTRUMENT_MAP = {
16
  'vox': {
17
+ 'pre': ['denoise.pth'],
18
+ 'mss': ['vox_mss.pth'],
19
+ 'post': ['dereverb.pth']
20
  },
21
  'gtr': {
22
+ 'pre': ['denoise.pth'],
23
+ 'mss': ['gtr_mss.pth'],
24
  'post': []
25
  },
26
  'key': {
27
+ 'pre': ['denoise.pth'],
28
+ 'mss': ['key_mss.pth'],
29
  'post': []
30
  },
31
  'syn': {
32
+ 'pre': ['denoise.pth'],
33
+ 'mss': ['syn_mss.pth', 'syn_mss1.pth'],
34
  'post': []
35
  },
36
  'bass': {
37
+ 'pre': ['denoise.pth'],
38
+ 'mss': ['bass_mss.pth'],
39
  'post': []
40
  },
41
  'drums': {
42
+ 'pre': ['denoise.pth'],
43
+ 'mss': ['drums_mss.pth', 'drums_mss1.pth'],
44
  'post': []
45
  },
46
  'perc': {
47
+ 'pre': ['denoise.pth'],
48
+ 'mss': ['perc_mss.pth', 'perc_mss1.pth'],
49
  'post': []
50
  },
51
  'orch': {
52
+ 'pre': ['denoise.pth'],
53
+ 'mss': ['orch_mss.pth', 'orch_mss1.pth'],
54
  'post': []
55
  }
56
  }
inference.py CHANGED
@@ -14,6 +14,7 @@ import soundfile as sf
14
  import numpy as np
15
  from tqdm import tqdm
16
  import librosa
 
17
  from models import MelRNN, MelRoFormer, UNet, UFormer
18
  from models.bs_roformer import bs_roformer as BSRoformer
19
  from models.bs_roformer import mel_band_roformer as MelBandRoformer
@@ -55,9 +56,14 @@ class RoformerSequential(nn.Sequential):
55
  return self[-1](mixture, target) # also pass target if present
56
 
57
  def load_config_and_state_dict(path: str, map_location: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
58
- if path.endswith('.pth'):
59
- raise ValueError("Use .ckpt files instead of .pth files")
60
  print(f"Extracting state dict from {path}")
 
 
 
 
 
 
 
61
  full_checkpoint = torch.load(path, map_location=map_location, weights_only=False)
62
  full_state_dict = full_checkpoint['state_dict']
63
  generator_state_dict = OrderedDict()
@@ -119,9 +125,9 @@ def process_audio(config, audio: np.ndarray, generator: nn.Module, device: str =
119
 
120
  def main():
121
  parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
122
- parser.add_argument("--checkpoint", '-c', type=str, required=True, help="Path to unwrapped generator weights (.ckpt)")
123
- parser.add_argument("--checkpoint_pre", '-p', type=str, help="pre-processing model checkpoint (.ckpt)")
124
- parser.add_argument("--checkpoint_post", '-P', type=str, help="post-processing model checkpoint (.ckpt)")
125
  parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input .flac files")
126
  parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio")
127
  parser.add_argument("--instrument", type=str, help="Instrument to process (Vox/Gtr/Kbs/Synth/Bass/Rhy_DK/Rhy_PERC/Orch)")
 
14
  import numpy as np
15
  from tqdm import tqdm
16
  import librosa
17
+ import yaml
18
  from models import MelRNN, MelRoFormer, UNet, UFormer
19
  from models.bs_roformer import bs_roformer as BSRoformer
20
  from models.bs_roformer import mel_band_roformer as MelBandRoformer
 
56
  return self[-1](mixture, target) # also pass target if present
57
 
58
  def load_config_and_state_dict(path: str, map_location: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
 
59
  print(f"Extracting state dict from {path}")
60
+ if path.endswith('.pth'):
61
+ model_name = Path(path).stem
62
+ config_path = f"./configs/{model_name}.yaml" # use config file with same name as model in ./configs
63
+ print(f"Loading config from {config_path}")
64
+ with open(config_path, 'r') as f:
65
+ config = yaml.load(f, Loader=yaml.FullLoader)
66
+ return config, torch.load(path, map_location=map_location)
67
  full_checkpoint = torch.load(path, map_location=map_location, weights_only=False)
68
  full_state_dict = full_checkpoint['state_dict']
69
  generator_state_dict = OrderedDict()
 
125
 
126
  def main():
127
  parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
128
+ parser.add_argument("--checkpoint", '-c', type=str, required=True, help="Path to unwrapped generator weights (.ckpt or .pth)")
129
+ parser.add_argument("--checkpoint_pre", '-p', type=str, help="pre-processing model checkpoint (.ckpt or .pth)")
130
+ parser.add_argument("--checkpoint_post", '-P', type=str, help="post-processing model checkpoint (.ckpt or .pth)")
131
  parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input .flac files")
132
  parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio")
133
  parser.add_argument("--instrument", type=str, help="Instrument to process (Vox/Gtr/Kbs/Synth/Bass/Rhy_DK/Rhy_PERC/Orch)")
inference_full.py CHANGED
@@ -171,9 +171,9 @@ def inference_main(args):
171
 
172
  if __name__ == '__main__':
173
  parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
174
- parser.add_argument("--checkpoint", '-c', nargs='*', default=[], type=str, help="model checkpoint (.ckpt)")
175
- parser.add_argument("--checkpoint_pre", '-p', nargs='*', default=[], type=str, help="pre-processing model checkpoint (.ckpt)")
176
- parser.add_argument("--checkpoint_post", '-P', nargs='*', default=[], type=str, help="post-processing model checkpoint (.ckpt)")
177
  parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input files, or a single audio file")
178
  parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio, or a single audio file name")
179
  parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
 
171
 
172
  if __name__ == '__main__':
173
  parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
174
+ parser.add_argument("--checkpoint", '-c', nargs='*', default=[], type=str, help="model checkpoint (.ckpt or .pth)")
175
+ parser.add_argument("--checkpoint_pre", '-p', nargs='*', default=[], type=str, help="pre-processing model checkpoint (.ckpt or .pth)")
176
+ parser.add_argument("--checkpoint_post", '-P', nargs='*', default=[], type=str, help="post-processing model checkpoint (.ckpt or .pth)")
177
  parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input files, or a single audio file")
178
  parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio, or a single audio file name")
179
  parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
start_script.sh CHANGED
@@ -1,8 +1,8 @@
1
  # inference_full (for inference on full audio)
2
  python inference_full.py \
3
- -p checkpoints/denoise.ckpt \
4
- -c checkpoints/vox_mss.ckpt \
5
- -P checkpoints/dereverb.ckpt \ # only vox needs it
6
  -i test/input \
7
  -o test/output \
8
  --device cuda \
@@ -10,8 +10,8 @@ python inference_full.py \
10
 
11
  # another example (eight instruments are supported in total)
12
  python inference_full.py \
13
- -p checkpoints/denoise.ckpt \
14
- -c checkpoints/drums_mss.ckpt checkpoints/drums_mss1.ckpt \ # use multiple checkpoints for some instruments
15
  -i test/input \
16
  -o test/output \
17
  --device cuda \
@@ -39,4 +39,7 @@ python inference.py \
39
  -P logs/dereverb/vox/checkpoints/00000010.ckpt \
40
  -i OrganizersMixture/Vocals \
41
  -o Answer/Vocals/default \
42
- --no-eval
 
 
 
 
1
  # inference_full (for inference on full audio)
2
  python inference_full.py \
3
+ -p checkpoints/denoise.pth \
4
+ -c checkpoints/vox_mss.pth \
5
+ -P checkpoints/dereverb.pth \ # only vox needs it
6
  -i test/input \
7
  -o test/output \
8
  --device cuda \
 
10
 
11
  # another example (eight instruments are supported in total)
12
  python inference_full.py \
13
+ -p checkpoints/denoise.pth \
14
+ -c checkpoints/drums_mss.pth checkpoints/drums_mss1.pth \ # use multiple checkpoints for some instruments
15
  -i test/input \
16
  -o test/output \
17
  --device cuda \
 
39
  -P logs/dereverb/vox/checkpoints/00000010.ckpt \
40
  -i OrganizersMixture/Vocals \
41
  -o Answer/Vocals/default \
42
+ --no-eval
43
+
44
+ # unwrap
45
+ python unwrap.py -i checkpoints/bass_mss.ckpt -o checkpoints/bass_mss.pth