| import torch |
| import torchaudio |
| import os |
| import gc |
|
|
| def frequency_blend_phases(phase1, phase2, freq_bins, low_cutoff=500, high_cutoff=5000, base_factor=0.25, scale_factor=1.85): |
| if phase1.shape != phase2.shape: |
| raise ValueError("phase1 and phase2 must have the same shape.") |
| if len(freq_bins) != phase1.shape[0]: |
| raise ValueError("freq_bins must have the same length as the number of frequency bins in phase1 and phase2.") |
| if low_cutoff >= high_cutoff: |
| raise ValueError("low_cutoff must be less than high_cutoff.") |
|
|
| blended_phase = torch.zeros_like(phase1) |
| blend_factors = torch.zeros_like(freq_bins) |
|
|
| blend_factors[freq_bins < low_cutoff] = base_factor |
| blend_factors[freq_bins > high_cutoff] = base_factor + scale_factor |
|
|
| in_range_mask = (freq_bins >= low_cutoff) & (freq_bins <= high_cutoff) |
| blend_factors[in_range_mask] = base_factor + scale_factor * ( |
| (freq_bins[in_range_mask] - low_cutoff) / (high_cutoff - low_cutoff) |
| ) |
|
|
| for i in range(phase1.shape[0]): |
| blended_phase[i, :] = (1 - blend_factors[i]) * phase1[i, :] + blend_factors[i] * phase2[i, :] |
|
|
| blended_phase = torch.remainder(blended_phase + torch.pi, 2 * torch.pi) - torch.pi |
|
|
| return blended_phase |
|
|
| def transfer_magnitude_phase(source_file, target_file, output_folder, transfer_magnitude=False, transfer_phase=True, |
| low_cutoff=500, high_cutoff=9000, scale_factor=1.4, output_format='flac'): |
| target_name, target_ext = os.path.splitext(os.path.basename(target_file)) |
| |
| target_name = target_name.replace("_other", "").replace("_vocals", "").replace("_instrumental", "") |
| target_name = target_name.replace("_Other", "").replace("_Vocals", "").replace("_Instrumental", "").strip() |
| |
| ext = '.flac' if output_format == 'flac' else '.wav' |
| output_file = os.path.join(output_folder, f"{target_name} (Fixed Instrumental){ext}") |
|
|
| print(f"Phase Fixing: {os.path.basename(target_file)}...") |
| source_waveform, source_sr = torchaudio.load(source_file) |
| target_waveform, target_sr = torchaudio.load(target_file) |
|
|
| if source_sr != target_sr: |
| raise ValueError("Sample rates of source and target audio files must match.") |
|
|
| n_fft = 2048 |
| hop_length = 512 |
| window = torch.hann_window(n_fft) |
|
|
| source_stfts = torch.stft(source_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect") |
| target_stfts = torch.stft(target_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect") |
|
|
| freqs = torch.linspace(0, source_sr // 2, steps=n_fft // 2 + 1) |
|
|
| modified_stfts = [] |
| for source_stft, target_stft in zip(source_stfts, target_stfts): |
| source_mag, source_phs = torch.abs(source_stft), torch.angle(source_stft) |
| target_mag, target_phs = torch.abs(target_stft), torch.angle(target_stft) |
|
|
| modified_stft = target_stft.clone() |
| if transfer_magnitude: |
| modified_stft = source_mag * torch.exp(1j * torch.angle(modified_stft)) |
|
|
| if transfer_phase: |
| blended_phase = frequency_blend_phases(target_phs, source_phs, freqs, low_cutoff, high_cutoff, scale_factor=scale_factor) |
| modified_stft = torch.abs(modified_stft) * torch.exp(1j * blended_phase) |
|
|
| modified_stfts.append(modified_stft) |
|
|
| modified_waveform = torch.istft( |
| torch.stack(modified_stfts), |
| n_fft=n_fft, |
| hop_length=hop_length, |
| window=window, |
| length=source_waveform.size(1) |
| ) |
|
|
| if output_format == 'flac': |
| torchaudio.save(output_file, modified_waveform, target_sr, format="flac", bits_per_sample=16) |
| else: |
| torchaudio.save(output_file, modified_waveform, target_sr) |
| |
| print(f"Saved: {output_file}") |
| return output_file |
|
|
| def process_phase_fix(source_file, target_file, output_folder, low_cutoff=500, high_cutoff=9000, |
| scale_factor=1.4, output_format='flac'): |
| os.makedirs(output_folder, exist_ok=True) |
| |
| try: |
| output_file = transfer_magnitude_phase( |
| source_file=source_file, |
| target_file=target_file, |
| output_folder=output_folder, |
| transfer_magnitude=False, |
| transfer_phase=True, |
| low_cutoff=low_cutoff, |
| high_cutoff=high_cutoff, |
| scale_factor=scale_factor, |
| output_format=output_format |
| ) |
| gc.collect() |
| torch.cuda.empty_cache() if torch.cuda.is_available() else None |
| return output_file, "Phase fix completed successfully!" |
| except Exception as e: |
| return None, f"Error during phase fix: {str(e)}" |
|
|
| SOURCE_MODELS = [ |
| 'VOCALS-MelBand-Roformer (by Becruily)', |
| 'VOCALS-Mel-Roformer big beta 4 (by unwa)', |
| 'VOCALS-Melband-Roformer BigBeta5e (by unwa)', |
| 'VOCALS-big_beta6 (by Unwa)', |
| 'VOCALS-big_beta6X (by Unwa)', |
| 'VOCALS-MelBand-Roformer (by KimberleyJSN)', |
| 'VOCALS-MelBand-Roformer Kim FT (by Unwa)', |
| 'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)', |
| 'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)', |
| 'VOCALS-Mel-Roformer FT 3 Preview (by unwa)', |
| 'VOCALS-BS-Roformer_1296 (by viperx)', |
| 'VOCALS-BS-Roformer_1297 (by viperx)', |
| 'VOCALS-BS-RoformerLargev1 (by unwa)', |
| 'bs_roformer_revive (by unwa)' |
| ] |
|
|
| TARGET_MODELS = [ |
| 'INST-MelBand-Roformer (by Becruily)', |
| 'INST-Mel-Roformer v1 (by unwa)', |
| 'INST-Mel-Roformer v2 (by unwa)', |
| 'inst_v1e (by unwa)', |
| 'INST-Mel-Roformer v1e+ (by unwa)', |
| 'Inst_GaboxV7 (by Gabox)', |
| 'INST-VOC-Mel-Roformer a.k.a. duality (by unwa)', |
| 'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)', |
| 'inst_gabox (by Gabox)', |
| 'inst_gaboxFlowersV10 (by Gabox)' |
| ] |
|
|