# Demo for inference on full audio. Only supports the final model.
# Processes audio like ans.py but accepts parameters like inference.py.
import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import soundfile as sf
import librosa
import torch.nn.functional as F
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm

from inference import load_config_and_state_dict, load_generator
def demix_single(
    model: nn.Module,
    mix: torch.Tensor,  # (b, c, l)
) -> torch.Tensor:
    device = next(model.parameters()).device
    mix = mix.float().to(device)
    # Run under fp16 autocast on GPU; fall back to full precision elsewhere.
    with torch.autocast(device_type=device.type, dtype=torch.float16,
                        enabled=device.type == 'cuda'):
        with torch.no_grad():
            output = model(mix)
            output = output.cpu()
    # Zero-pad the model output back to the input length.
    return F.pad(output, (0, mix.shape[-1] - output.shape[-1]), mode='constant', value=0)
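
# A minimal shape-only smoke test (hypothetical toy "model", not the real
# generator; assumes a CUDA device is available for the autocast path):
#   m = nn.Conv1d(2, 2, kernel_size=1).cuda()
#   out = demix_single(m, torch.randn(4, 2, 44100))
#   assert out.shape == (4, 2, 44100)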
def split_audio(
    audio: np.ndarray,
    chunk_size: int,
    overlap_size: int,
) -> list[np.ndarray]:
    """Split audio into overlapping chunks."""
    hop_size = chunk_size - overlap_size
    chunks = []
    start = 0
    while start + chunk_size <= audio.shape[-1]:
        chunks.append(audio[..., start:start + chunk_size])
        start += hop_size
    return chunks
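
# Worked example: chunk_size=10 and overlap_size=2 give hop_size=8, so a
# 26-sample signal splits into 3 chunks starting at samples 0, 8, and 16:
#   chunks = split_audio(np.zeros((2, 26)), chunk_size=10, overlap_size=2)
#   assert len(chunks) == 3 and chunks[0].shape == (2, 10)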
def merge_chunks(
    chunks: list[np.ndarray],
    chunk_size: int,
    overlap_size: int,
) -> np.ndarray:
    """Merge chunks with overlap-add, linearly crossfading the overlaps."""
    hop_size = chunk_size - overlap_size
    output = np.zeros((chunks[0].shape[0], hop_size * len(chunks) + overlap_size))
    for i, chunk in enumerate(chunks):
        start = i * hop_size
        end = start + chunk_size
        window = np.ones_like(chunk)
        if overlap_size > 0:
            # Crossfade only where chunks actually overlap; the very first
            # and very last edges have no neighbour, so they stay at unity
            # gain (otherwise the start and end of the signal are attenuated).
            if i > 0:
                window[..., :overlap_size] *= np.linspace(0, 1, overlap_size)
            if i < len(chunks) - 1:
                window[..., -overlap_size:] *= np.linspace(1, 0, overlap_size)
        output[..., start:end] += chunk * window
    return output
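
# Because fade_in + fade_out == 1 at every point of an overlap, merging the
# chunks from the example above reconstructs the padded length exactly:
#   merged = merge_chunks(chunks, chunk_size=10, overlap_size=2)
#   assert merged.shape == (2, 8 * 3 + 2) == (2, 26)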
def process_long_audio(
    model: nn.Module,
    mix: np.ndarray,  # (c, l)
    sr: int,
    chunk_duration: float,
    overlap: float,
    batch_size: int,
) -> np.ndarray:
    chunk_size = int(chunk_duration * sr)
    overlap_size = int(overlap * sr)
    hop_size = chunk_size - overlap_size
    l = mix.shape[-1]
    # Zero-pad so the hops tile the signal exactly: ceil((l - overlap) / hop) hops.
    l_new = ((l - overlap_size + hop_size - 1) // hop_size) * hop_size + overlap_size
    padding = np.zeros((mix.shape[0], l_new - l), dtype=mix.dtype)
    mix = np.concatenate([mix, padding], axis=-1)
    print(f"Processing long audio of length {mix.shape[-1]} samples")
    chunks = split_audio(mix, chunk_size, overlap_size)
    batched_chunks = [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
    processed_chunks = []
    for batch in batched_chunks:
        print(f"Processing chunks {len(processed_chunks) + 1}-{len(processed_chunks) + len(batch)}/{len(chunks)}")
        tensor = default_collate(batch)  # Stack chunks along a new batch dim
        processed = demix_single(model, tensor)
        processed_chunks.extend([processed[i].numpy() for i in range(processed.shape[0])])  # Split off batch dim
    merged = merge_chunks(processed_chunks, chunk_size, overlap_size)
    return merged[..., :l]  # Trim the padding back off
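
# For reference: at a (hypothetical) 44.1 kHz sample rate, the defaults used
# below (chunk_duration=10.0, overlap=1.0) give chunk_size=441000,
# overlap_size=44100, and hop_size=396900 samples, i.e. each chunk advances
# by 9 s of audio.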
def inference(models, audio, sr, batch_size):
    # audio: (channels, samples)
    channels, samples = audio.shape
    for (config, model) in models:
        model_sr = config['data']['sample_rate']
        if sr != model_sr:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=model_sr)
        # After resampling, the audio is at model_sr, so chunk sizes must be
        # computed from model_sr rather than the file's original rate.
        audio = process_long_audio(model, audio, model_sr, chunk_duration=10.0, overlap=1.0, batch_size=batch_size)
        if sr != model_sr:
            audio = librosa.resample(audio, orig_sr=model_sr, target_sr=sr)
        # Resampling back can change the length by a few samples; restore it.
        if samples < audio.shape[1]:
            audio = audio[:, :samples]
        if samples > audio.shape[1]:
            audio = np.pad(audio, ((0, 0), (0, samples - audio.shape[1])), 'constant')
    return audio
def load_models(paths, device):
    models = []
    for path in paths:
        config, state_dict = load_config_and_state_dict(path, device)
        model = load_generator(config, state_dict, device=device)
        models.append((config, model))
    return models
def load_audio(input_path):
    audio, sr = sf.read(input_path)  # soundfile returns (samples,) or (samples, channels)
    if audio.ndim == 1:  # mono audio, convert to stereo
        audio = np.stack([audio, audio], axis=0)
    else:
        audio = audio.T  # -> (channels, samples)
    return audio, sr

def save_audio(audio, sr, output_path):
    audio = audio.T  # back to (samples, channels) for soundfile
    sf.write(output_path, audio, sr)
def calculate_rms(audio):
    rms = np.sqrt(np.mean(audio**2))
    rms_db = 20 * np.log10(rms + 1e-10)
    return rms_db
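
# Sanity check: a full-scale sine has RMS 1/sqrt(2), i.e. about -3.01 dB:
#   t = np.linspace(0, 1, 44100, endpoint=False)
#   assert abs(calculate_rms(np.sin(2 * np.pi * 440 * t)) + 3.01) < 0.01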
def inference_main(args):
    pre_models = load_models(args.checkpoint_pre, args.device)
    mss_models = load_models(args.checkpoint, args.device)
    post_models = load_models(args.checkpoint_post, args.device)
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    if input_dir.is_dir():
        # Get all audio files
        audio_files = sorted(input_dir.glob("*.flac")) + sorted(input_dir.glob("*.wav")) + sorted(input_dir.glob("*.mp3"))
        print(f"Found {len(audio_files)} audio files")
    else:
        audio_files = [input_dir]
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        audio, sr = load_audio(audio_file)
        print("Processing audio file:", audio_file)
        audio = inference(pre_models, audio, sr, batch_size=args.batch_size)
        audio = inference(mss_models, audio, sr, batch_size=args.batch_size)
        rms = calculate_rms(audio)
        print("RMS of MSS audio:", rms)
        audio_dereverb = inference(post_models, audio, sr, batch_size=args.batch_size)
        rms_dereverb = calculate_rms(audio_dereverb)
        print("RMS of dereverbed audio:", rms_dereverb)
        if rms - rms_dereverb > 10.0:
            # Dereverb occasionally collapses the signal; fall back to the
            # MSS output if it dropped the level by more than 10 dB.
            print("Dereverbed audio is too quiet, using the original")
        else:
            audio = audio_dereverb
        # If the input was a directory, mirror the file name into output_dir;
        # otherwise output_dir is itself the output file path.
        if input_dir.is_dir():
            output_path = output_dir / audio_file.name
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_path = output_dir
            output_path.parent.mkdir(parents=True, exist_ok=True)
        save_audio(audio, sr, output_path)
        print("Final result saved to:", output_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run inference on audio files using a trained generator")
    parser.add_argument("--checkpoint", '-c', nargs='*', default=[], type=str, help="model checkpoint (.ckpt or .pth)")
    parser.add_argument("--checkpoint_pre", '-p', nargs='*', default=[], type=str, help="pre-processing model checkpoint (.ckpt or .pth)")
    parser.add_argument("--checkpoint_post", '-P', nargs='*', default=[], type=str, help="post-processing model checkpoint (.ckpt or .pth)")
    parser.add_argument("--input_dir", '-i', type=str, help="Directory containing input files, or a single audio file")
    parser.add_argument("--output_dir", '-o', type=str, help="Directory to save processed audio, or a single output file name")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    args = parser.parse_args()
    inference_main(args)
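
# Example invocation (script and checkpoint names are hypothetical;
# substitute your own files):
#   python demo.py \
#       -p denoise.ckpt -c mss.ckpt -P dereverb.ckpt \
#       -i input_songs/ -o separated/ --device cuda --batch_size 4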