File size: 5,768 Bytes
ddadeb4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import torch
import torchaudio
import os
import gc
def frequency_blend_phases(phase1, phase2, freq_bins, low_cutoff=500, high_cutoff=5000,
                           base_factor=0.25, scale_factor=1.85):
    """Blend two phase tensors with a weight that depends on frequency.

    Each frequency bin (first axis of the phase tensors) gets a blend
    weight ``w``:

    * ``base_factor`` for bins below ``low_cutoff``,
    * ``base_factor + scale_factor`` for bins above ``high_cutoff``,
    * a linear ramp between those two values inside the cutoff range.

    The result is ``(1 - w) * phase1 + w * phase2`` wrapped into
    ``[-pi, pi)``.  Weights are not limited to ``[0, 1]``; with the
    defaults the top weight is 2.1, i.e. the blend extrapolates past
    ``phase2``.

    Args:
        phase1: Phase tensor whose first axis indexes frequency bins.
        phase2: Phase tensor with the same shape as ``phase1``.
        freq_bins: Frequency of each bin (tensor or array-like), one value
            per row of ``phase1``.
        low_cutoff: Frequency below which the weight is ``base_factor``.
        high_cutoff: Frequency above which the weight is
            ``base_factor + scale_factor``.
        base_factor: Weight applied at and below ``low_cutoff``.
        scale_factor: Amount added to the weight across the ramp.

    Returns:
        Tensor of blended phases with the same shape as ``phase1``.

    Raises:
        ValueError: If the phase shapes differ, if ``freq_bins`` does not
            have one entry per frequency row, or if
            ``low_cutoff >= high_cutoff``.
    """
    if phase1.shape != phase2.shape:
        raise ValueError("phase1 and phase2 must have the same shape.")
    if len(freq_bins) != phase1.shape[0]:
        raise ValueError("freq_bins must have the same length as the number of frequency bins in phase1 and phase2.")
    if low_cutoff >= high_cutoff:
        raise ValueError("low_cutoff must be less than high_cutoff.")

    # Build the weights in a floating dtype on the phase tensor's device.
    # Deriving them from freq_bins' own dtype would truncate the fractional
    # factors to integers whenever the bin frequencies are given as ints.
    weight_dtype = phase1.dtype if phase1.is_floating_point() else torch.get_default_dtype()
    freqs = torch.as_tensor(freq_bins, dtype=weight_dtype, device=phase1.device)

    # Position of each bin on the low->high ramp, clamped so bins outside
    # the cutoff range sit at the ramp's end points.
    ramp = ((freqs - low_cutoff) / (high_cutoff - low_cutoff)).clamp(0.0, 1.0)
    weights = base_factor + scale_factor * ramp

    # One weight per frequency row, broadcast over every remaining axis.
    weights = weights.reshape(-1, *([1] * (phase1.dim() - 1)))
    blended_phase = (1 - weights) * phase1 + weights * phase2

    # Wrap back into the principal interval [-pi, pi).
    return torch.remainder(blended_phase + torch.pi, 2 * torch.pi) - torch.pi
def transfer_magnitude_phase(source_file, target_file, output_folder, transfer_magnitude=False, transfer_phase=True,
                             low_cutoff=500, high_cutoff=9000, scale_factor=1.4, output_format='flac'):
    """Rebuild the target audio using magnitude and/or phase taken from the source.

    Both files are loaded with torchaudio, converted to STFTs (n_fft=2048,
    hop=512, Hann window) and processed channel by channel.  The result is
    written to ``output_folder`` as ``"<target name> (Fixed Instrumental)"``
    with a ``.flac`` or ``.wav`` extension, where the target name has any
    ``_other`` / ``_vocals`` / ``_instrumental`` tag (lower-case or
    capitalised) removed.

    Args:
        source_file: Path of the audio file that donates magnitude/phase.
        target_file: Path of the audio file being corrected.
        output_folder: Directory the result is written into.
        transfer_magnitude: If true, use the source magnitude instead of the
            target magnitude.
        transfer_phase: If true, replace the target phase with a
            frequency-weighted blend of target and source phase (see
            ``frequency_blend_phases``).
        low_cutoff: Lower cutoff frequency passed to the phase blend.
        high_cutoff: Upper cutoff frequency passed to the phase blend.
        scale_factor: Scale factor passed to the phase blend.
        output_format: ``'flac'`` writes 16-bit FLAC; any other value writes
            a ``.wav`` file with torchaudio's defaults.

    Returns:
        Path of the written output file.

    Raises:
        ValueError: If the two files differ in sample rate or channel count.
    """
    # Derive the output name from the target file, dropping stem tags.
    target_name = os.path.splitext(os.path.basename(target_file))[0]
    for stem_tag in ("_other", "_vocals", "_instrumental", "_Other", "_Vocals", "_Instrumental"):
        target_name = target_name.replace(stem_tag, "")
    target_name = target_name.strip()
    ext = '.flac' if output_format == 'flac' else '.wav'
    output_file = os.path.join(output_folder, f"{target_name} (Fixed Instrumental){ext}")

    print(f"Phase Fixing: {os.path.basename(target_file)}...")
    source_waveform, source_sr = torchaudio.load(source_file)
    target_waveform, target_sr = torchaudio.load(target_file)
    if source_sr != target_sr:
        raise ValueError("Sample rates of source and target audio files must match.")
    # Channels are paired with zip() below; without this check a channel
    # count mismatch would silently drop the extra channels from the output.
    if source_waveform.size(0) != target_waveform.size(0):
        raise ValueError("Channel counts of source and target audio files must match.")

    n_fft = 2048
    hop_length = 512
    window = torch.hann_window(n_fft)
    stft_options = dict(n_fft=n_fft, hop_length=hop_length, window=window,
                        return_complex=True, pad_mode="reflect")
    source_stfts = torch.stft(source_waveform, **stft_options)
    target_stfts = torch.stft(target_waveform, **stft_options)

    # Centre frequency of each STFT bin, from 0 Hz up to Nyquist.  True
    # division keeps Nyquist exact for odd sample rates.
    freqs = torch.linspace(0, source_sr / 2, steps=n_fft // 2 + 1)

    modified_stfts = []
    for source_stft, target_stft in zip(source_stfts, target_stfts):
        if not (transfer_magnitude or transfer_phase):
            # Nothing to transfer: keep the target spectrum untouched.
            modified_stfts.append(target_stft.clone())
            continue

        magnitude = torch.abs(source_stft if transfer_magnitude else target_stft)
        if transfer_phase:
            phase = frequency_blend_phases(
                torch.angle(target_stft), torch.angle(source_stft), freqs,
                low_cutoff, high_cutoff, scale_factor=scale_factor,
            )
        else:
            phase = torch.angle(target_stft)
        # Recombine magnitude and phase into a complex spectrum.
        modified_stfts.append(torch.polar(magnitude, phase))

    # The reconstructed signal is forced to the source's sample count.
    modified_waveform = torch.istft(
        torch.stack(modified_stfts),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        length=source_waveform.size(1)
    )

    if output_format == 'flac':
        torchaudio.save(output_file, modified_waveform, target_sr, format="flac", bits_per_sample=16)
    else:
        torchaudio.save(output_file, modified_waveform, target_sr)
    print(f"Saved: {output_file}")
    return output_file
def process_phase_fix(source_file, target_file, output_folder, low_cutoff=500, high_cutoff=9000,
                      scale_factor=1.4, output_format='flac'):
    """Run a phase-only fix and report the outcome as a ``(path, message)`` pair.

    Creates ``output_folder`` if needed, then calls
    ``transfer_magnitude_phase`` with ``transfer_magnitude=False`` and
    ``transfer_phase=True``.  Any exception raised by the processing is
    converted into an error message instead of propagating.

    Args:
        source_file: Path of the audio file that donates phase.
        target_file: Path of the audio file being corrected.
        output_folder: Directory for the result; created if missing.
        low_cutoff: Lower cutoff frequency for the phase blend.
        high_cutoff: Upper cutoff frequency for the phase blend.
        scale_factor: Scale factor for the phase blend.
        output_format: Output format forwarded to the processing step.

    Returns:
        ``(output_file, "Phase fix completed successfully!")`` on success,
        or ``(None, "Error during phase fix: <reason>")`` on failure.
    """
    os.makedirs(output_folder, exist_ok=True)
    try:
        output_file = transfer_magnitude_phase(
            source_file=source_file,
            target_file=target_file,
            output_folder=output_folder,
            transfer_magnitude=False,
            transfer_phase=True,
            low_cutoff=low_cutoff,
            high_cutoff=high_cutoff,
            scale_factor=scale_factor,
            output_format=output_format
        )
    except Exception as e:
        return None, f"Error during phase fix: {e}"
    finally:
        # Release memory on both the success and the failure path; a failed
        # run can leave large intermediate tensors behind as well.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_file, "Phase fix completed successfully!"
# Display names of the models listed as phase "source" choices.
SOURCE_MODELS = [
    "VOCALS-MelBand-Roformer (by Becruily)",
    "VOCALS-Mel-Roformer big beta 4 (by unwa)",
    "VOCALS-Melband-Roformer BigBeta5e (by unwa)",
    "VOCALS-big_beta6 (by Unwa)",
    "VOCALS-big_beta6X (by Unwa)",
    "VOCALS-MelBand-Roformer (by KimberleyJSN)",
    "VOCALS-MelBand-Roformer Kim FT (by Unwa)",
    "VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)",
    "VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)",
    "VOCALS-Mel-Roformer FT 3 Preview (by unwa)",
    "VOCALS-BS-Roformer_1296 (by viperx)",
    "VOCALS-BS-Roformer_1297 (by viperx)",
    "VOCALS-BS-RoformerLargev1 (by unwa)",
    "bs_roformer_revive (by unwa)",
]
# Display names of the models listed as phase-fix "target" choices.
TARGET_MODELS = [
    "INST-MelBand-Roformer (by Becruily)",
    "INST-Mel-Roformer v1 (by unwa)",
    "INST-Mel-Roformer v2 (by unwa)",
    "inst_v1e (by unwa)",
    "INST-Mel-Roformer v1e+ (by unwa)",
    "Inst_GaboxV7 (by Gabox)",
    "INST-VOC-Mel-Roformer a.k.a. duality (by unwa)",
    "INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)",
    "inst_gabox (by Gabox)",
    "inst_gaboxFlowersV10 (by Gabox)",
]
|