# xlance-msr / inference.py
# commit fd56f2f (Jihuai): use pth rather than ckpt
import argparse
from collections import OrderedDict
import copy
import os
from pathlib import Path
import re
import subprocess
import sys
from typing import Dict, Any, Tuple
import torch
import torch.nn as nn
import soundfile as sf
import numpy as np
from tqdm import tqdm
import librosa
import yaml
from models import MelRNN, MelRoFormer, UNet, UFormer
from models.bs_roformer import bs_roformer as BSRoformer
from models.bs_roformer import mel_band_roformer as MelBandRoformer
# Maps RawStems stem labels (as found in training configs under
# data.val_dataset.target_stem) to MSRBench directory names.
# NOTE(review): values are lowercase here and `.capitalize()`d by callers
# before being used as MSRBench folder names — confirm against the dataset layout.
RAWSTEMS_TO_MSRBENCH = {
'Voc': 'vox',        # vocals
'Gtr': 'gtr',        # guitar
'Kbs': 'key',        # keyboards
'Synth': 'syn',      # synthesizer
'Bass': 'bass',
'Rhy_DK': 'drums',   # rhythm: drum kit
'Rhy_PERC': 'perc',  # rhythm: percussion
'Orch': 'orch',      # orchestral
}
def init_generator(model_cfg):
    """Build the generator network described by ``model_cfg``.

    ``model_cfg`` must contain a ``'name'`` selecting the architecture and a
    ``'params'`` dict of constructor keyword arguments.

    Raises:
        ValueError: if ``model_cfg['name']`` is not a known architecture.
    """
    # Dispatch table instead of an if/elif chain; each factory receives the
    # raw params dict so unknown names never touch 'params'.
    factories = {
        'MelRNN': lambda p: MelRNN.MelRNN(**p),
        'MelRoFormer': lambda p: MelRoFormer.MelRoFormer(**p),
        'MelUNet': lambda p: UNet.MelUNet(**p),
        # UFormer wraps its params in a dedicated config object.
        'UFormer': lambda p: UFormer.UFormer(UFormer.UFormerConfig(**p)),
        'BSRoFormer': lambda p: BSRoformer.BSRoformer(**p),
        'MelBandRoformer': lambda p: MelBandRoformer.MelBandRoformer(**p),
    }
    factory = factories.get(model_cfg['name'])
    if factory is None:
        raise ValueError(f"Unknown model name: {model_cfg['name']}")
    return factory(model_cfg['params'])
class RoformerSequential(nn.Sequential):
    """``nn.Sequential`` variant whose final stage also receives the target.

    Every stage except the last is invoked as ``stage(x)``; the last stage is
    invoked as ``stage(x, target)`` so a trailing Roformer-style model can
    consume an optional separation target.
    """

    def forward(self, mixture, target=None):
        stages = list(self)
        for stage in stages[:-1]:
            mixture = stage(mixture)
        # Only the final stage sees the (possibly None) target.
        return stages[-1](mixture, target)
def load_config_and_state_dict(path: str, map_location: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return ``(config, generator_state_dict)`` for a checkpoint file.

    Two layouts are supported:
      * ``*.pth`` — bare generator weights; the config is read from
        ``./configs/<stem>.yaml`` (a YAML file named after the checkpoint).
      * anything else (Lightning ``.ckpt``) — the config comes from the
        checkpoint's ``hyper_parameters`` and the generator weights are
        extracted from ``state_dict`` by stripping the ``'generator.'`` prefix.
    """
    print(f"Extracting state dict from {path}")
    if path.endswith('.pth'):
        model_name = Path(path).stem
        config_path = f"./configs/{model_name}.yaml"  # use config file with same name as model in ./configs
        print(f"Loading config from {config_path}")
        with open(config_path, 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config, torch.load(path, map_location=map_location)

    checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    prefix = 'generator.'
    # Keep only the generator's weights, with the wrapper prefix removed.
    generator_state_dict = OrderedDict(
        (key[len(prefix):], value)
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix)
    )
    return checkpoint['hyper_parameters'], generator_state_dict
def load_generator(config: Dict[str, Any], state_dict: Dict[str, Any], device: str = 'cuda') -> nn.Module:
    """Instantiate the generator from ``config``, load weights, move to
    ``device``, and switch it to eval mode."""
    generator = init_generator(config['model'])
    if 'model1' in config:
        # Two-stage setup: chain an identical second copy of the model behind
        # the first; the pair is wrapped so the last stage can take a target.
        generator = RoformerSequential(generator, copy.deepcopy(generator))
    generator.load_state_dict(state_dict)
    return generator.to(device).eval()
def process_audio(config, audio: np.ndarray, generator: nn.Module, device: str = 'cuda') -> np.ndarray:
    """Process a single audio array through the generator.

    Args:
        config: experiment config; only ``config['model']['name']`` is read,
            to decide batching and mixed-precision behavior.
        audio: mono ``(samples,)`` or multichannel ``(channels, samples)``.
        generator: the separation model, already on ``device`` and in eval mode.
        device: torch device string the input tensor is moved to.

    Returns:
        ``(channels, samples)`` numpy array (mono input gains a channel dim).
    """
    model_name = config['model']['name']
    # These models take a leading batch dimension.
    use_channel = model_name in ('BSRoFormer', 'UFormer', 'MelBandRoformer')
    # These models run under fp16 autocast with the MATH SDP backend.
    use_16_mix = model_name in ('BSRoFormer', 'MelBandRoformer')

    if audio.ndim == 1:
        audio = audio[np.newaxis, :]  # mono -> (1, samples)
    audio_tensor = torch.from_numpy(audio).float().to(device)
    if use_channel:
        audio_tensor = audio_tensor.unsqueeze(0)  # (channels, samples) -> (1, channels, samples)

    with torch.no_grad():
        if use_16_mix:
            # BUG FIX: autocast on the actual device type instead of a
            # hard-coded 'cuda', so --device cpu does not silently mismatch.
            with torch.autocast(device_type=torch.device(device).type, dtype=torch.float16), \
                 torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
                output_tensor = generator(audio_tensor)
        else:
            output_tensor = generator(audio_tensor)

    output_audio = output_tensor.cpu().numpy()
    if use_channel:
        output_audio = output_audio[0]  # drop batch dimension
    return output_audio
def main():
    """CLI entry point: run separation inference over a directory of audio.

    Optionally chains a pre-processing and/or post-processing model around the
    main separator (each stage's output is written to disk and fed to the next
    stage), then — unless ``--no-eval`` is given — shells out to eval_plus.py.
    """
    parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
    parser.add_argument("--checkpoint", '-c', type=str, required=True, help="Path to unwrapped generator weights (.ckpt or .pth)")
    parser.add_argument("--checkpoint_pre", '-p', type=str, help="pre-processing model checkpoint (.ckpt or .pth)")
    parser.add_argument("--checkpoint_post", '-P', type=str, help="post-processing model checkpoint (.ckpt or .pth)")
    parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input .flac files")
    parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio")
    parser.add_argument("--instrument", type=str, help="Instrument to process (Vox/Gtr/Kbs/Synth/Bass/Rhy_DK/Rhy_PERC/Orch)")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
    parser.add_argument("--no-eval", action="store_false", dest="eval", help="Skip evaluation after inference")
    parser.add_argument("--target_index", type=str, help="Index of target audio files, e.g. '11|12'")
    args = parser.parse_args()

    config, state_dict = load_config_and_state_dict(args.checkpoint, args.device)
    project_name = config['project_name']
    exp_name = config['exp_name']
    step = Path(args.checkpoint).stem  # checkpoint filename doubles as the step/run tag
    # Default instrument comes from the training config's target stem mapped
    # into MSRBench naming; an explicit --instrument overrides it.
    if args.instrument is None:
        instrument = RAWSTEMS_TO_MSRBENCH[config['data']['val_dataset']['target_stem']].capitalize()
    else:
        instrument = args.instrument.capitalize()
    print(f"Project: {project_name}, Exp: {exp_name}, Step: {step}, Instrument: {instrument}")

    if not args.input_dir:
        args.input_dir = f"../../data/MSRBench/{instrument}/mixture/"
        print(f"No input directory specified, using default: {args.input_dir}")
    if not args.output_dir:
        args.output_dir = f"output/{project_name}/{exp_name}_{step}/"
        print(f"No output directory specified, using default: {args.output_dir}")

    # Setup paths
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    # exist_ok=False on purpose: refuse to clobber a previous run's output.
    output_dir.mkdir(parents=True, exist_ok=False)

    # Gather candidate files, then optionally filter by data-track index,
    # e.g. --target_index '11|12' keeps only files named like '3_DT11.flac'.
    audio_files = sorted(input_dir.glob("*.flac")) + sorted(input_dir.glob("*.wav"))
    if args.target_index is not None:
        regex = re.compile(rf"^\d+_DT({args.target_index})\.\w+$")
    else:
        regex = re.compile(r"^.*\.\w+$")
    audio_files = [f for f in audio_files if regex.match(os.path.basename(f))]
    audio_files.sort()
    if len(audio_files) == 0:
        print(f"No .flac or .wav files found in {input_dir}")
        return
    print(f"Found {len(audio_files)} audio files")

    # Build the (pre ->) separator (-> post) chain; each entry is
    # (model, filename postfix for that stage's intermediate output).
    generators = []
    if args.checkpoint_pre:
        print(f"Loading pre-processing model from {args.checkpoint_pre}...")
        config_pre, state_dict_pre = load_config_and_state_dict(args.checkpoint_pre, args.device)
        generator_pre = load_generator(config_pre, state_dict_pre, device=args.device)
        generators.append((generator_pre, "_pre"))
    # Load model
    print(f"Loading generator from {args.checkpoint}...")
    generator = load_generator(config, state_dict, device=args.device)
    generators.append((generator, "_sep" if args.checkpoint_post else ""))
    if args.checkpoint_post:
        print(f"Loading post-processing model from {args.checkpoint_post}...")
        config_post, state_dict_post = load_config_and_state_dict(args.checkpoint_post, args.device)
        generator_post = load_generator(config_post, state_dict_post, device=args.device)
        generators.append((generator_post, ""))

    # Process each file through every stage in order.
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        input_path = audio_file
        for generator, postfix in generators:
            audio, sr = sf.read(input_path)
            # NOTE(review): the MAIN config's sample rate is applied to every
            # stage; pre/post models are assumed to share it — confirm.
            model_sr = config['data']['sample_rate']
            if audio.ndim == 2:
                audio = audio.T  # soundfile yields (samples, channels); models want (channels, samples)
            if sr != model_sr:
                # BUG FIX: librosa >= 0.10 requires keyword sample-rate args;
                # the keyword form also works on older librosa versions.
                audio = librosa.resample(audio, orig_sr=sr, target_sr=model_sr)
            # Process through generator
            output_audio = process_audio(config, audio, generator, device=args.device)
            if sr != model_sr:
                output_audio = librosa.resample(output_audio, orig_sr=model_sr, target_sr=sr)
            if output_audio.ndim == 2:
                output_audio = output_audio.T  # back to (samples, channels) for soundfile
            # Save with same filename (plus the stage postfix, if any).
            output_path = output_dir / ((audio_file.stem + postfix) + audio_file.suffix)
            sf.write(output_path, output_audio, sr)
            input_path = output_path  # feed this stage's output to the next stage

    print(f"\nProcessing complete! Output saved to {output_dir}")

    if args.eval:
        # Run the evaluation script that sits next to this file.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        program2_path = os.path.join(current_dir, "eval_plus.py")
        cmd = [sys.executable, program2_path]
        arg_eval = {
            '--target_dir': f"../../data/MSRBench/{instrument}/target/",
            '--output_dir': args.output_dir,
            '--target_index': args.target_index,
        }
        for key, value in arg_eval.items():
            if value is None:
                continue  # skip options the user did not provide
            cmd.extend([key, str(value)])
        subprocess.run(cmd)