SESA_Audio_Separation / phase_fixer.py
ASesYusuf1's picture
SESA: GitHub'dan güncel dosyalar aktarıldı - 2026-03-18 23:04
ddadeb4
import torch
import torchaudio
import os
import gc
def frequency_blend_phases(phase1, phase2, freq_bins, low_cutoff=500, high_cutoff=5000, base_factor=0.25, scale_factor=1.85):
    """Blend two phase tensors with a frequency-dependent weight.

    For every frequency bin ``i`` the result is
    ``(1 - w[i]) * phase1[i] + w[i] * phase2[i]``, where the weight ``w`` is
    ``base_factor`` below ``low_cutoff``, ``base_factor + scale_factor`` above
    ``high_cutoff`` and rises linearly between the two cutoffs. The blended
    values are finally wrapped into ``[-pi, pi)``.

    The mix is a plain linear combination of the angle values (the inputs are
    not unwrapped first), and weights greater than 1 extrapolate beyond
    ``phase2``.

    Args:
        phase1: Phase tensor whose first dimension indexes frequency bins.
        phase2: Phase tensor with exactly the same shape as ``phase1``.
        freq_bins: One frequency per bin (tensor or sequence), expressed in
            the same unit as the cutoffs. Integer values are accepted.
        low_cutoff: Frequency below which the weight is ``base_factor``.
        high_cutoff: Frequency above which the weight is
            ``base_factor + scale_factor``.
        base_factor: Weight of ``phase2`` at and below ``low_cutoff``.
        scale_factor: Additional weight reached at and above ``high_cutoff``.

    Returns:
        A tensor shaped like ``phase1`` holding the wrapped, blended phase.

    Raises:
        ValueError: If the phase shapes differ, if ``freq_bins`` does not
            have one entry per frequency bin, or if
            ``low_cutoff >= high_cutoff``.
    """
    if phase1.shape != phase2.shape:
        raise ValueError("phase1 and phase2 must have the same shape.")
    if len(freq_bins) != phase1.shape[0]:
        raise ValueError("freq_bins must have the same length as the number of frequency bins in phase1 and phase2.")
    if low_cutoff >= high_cutoff:
        raise ValueError("low_cutoff must be less than high_cutoff.")

    # Compute the weights in the phase dtype: deriving them from freq_bins'
    # own dtype would truncate (or reject) fractional weights whenever the
    # bin frequencies are integers.
    freqs = torch.as_tensor(freq_bins, dtype=phase1.dtype, device=phase1.device)

    # Position of each bin on the low->high ramp, clamped so that bins outside
    # the band get the constant end-point weights.
    ramp = ((freqs - low_cutoff) / (high_cutoff - low_cutoff)).clamp(0.0, 1.0)
    blend_factors = base_factor + scale_factor * ramp

    # Give the weights trailing singleton dimensions so that they broadcast
    # along every non-frequency axis; this replaces a per-bin Python loop.
    blend_factors = blend_factors.reshape(-1, *([1] * (phase1.dim() - 1)))
    blended_phase = (1 - blend_factors) * phase1 + blend_factors * phase2

    # Wrap back into the principal interval [-pi, pi).
    return torch.remainder(blended_phase + torch.pi, 2 * torch.pi) - torch.pi
def transfer_magnitude_phase(source_file, target_file, output_folder, transfer_magnitude=False, transfer_phase=True,
                             low_cutoff=500, high_cutoff=9000, scale_factor=1.4, output_format='flac'):
    """Rewrite the target audio using magnitude and/or phase taken from the source.

    Both files are loaded and transformed with an STFT. Per channel, the
    target spectrum optionally receives the source magnitude and/or a
    frequency-weighted blend of the target and source phase (see
    ``frequency_blend_phases``). The result is inverted back to a waveform and
    saved as ``"<target name> (Fixed Instrumental)<ext>"`` in ``output_folder``.

    Args:
        source_file: Path of the audio file supplying the reference
            phase/magnitude.
        target_file: Path of the audio file to be corrected; its base name
            (minus stem suffixes such as ``_vocals``) names the output.
        output_folder: Directory the output file is written to.
        transfer_magnitude: Replace the target magnitude with the source's.
        transfer_phase: Blend the source phase into the target phase.
        low_cutoff: Lower cutoff forwarded to ``frequency_blend_phases``.
        high_cutoff: Upper cutoff forwarded to ``frequency_blend_phases``.
        scale_factor: Scale factor forwarded to ``frequency_blend_phases``.
        output_format: ``'flac'`` writes 16-bit FLAC; any other value writes
            a ``.wav`` file.

    Returns:
        The path of the written output file.

    Raises:
        ValueError: If the sample rates or channel counts of the two files
            differ (or, via ``frequency_blend_phases``, if their STFT shapes
            differ).
    """
    target_name = os.path.splitext(os.path.basename(target_file))[0]
    # Strip stem suffixes in the same order as before (lower-case first).
    for suffix in ("_other", "_vocals", "_instrumental", "_Other", "_Vocals", "_Instrumental"):
        target_name = target_name.replace(suffix, "")
    target_name = target_name.strip()

    ext = '.flac' if output_format == 'flac' else '.wav'
    output_file = os.path.join(output_folder, f"{target_name} (Fixed Instrumental){ext}")
    print(f"Phase Fixing: {os.path.basename(target_file)}...")

    source_waveform, source_sr = torchaudio.load(source_file)
    target_waveform, target_sr = torchaudio.load(target_file)
    if source_sr != target_sr:
        raise ValueError("Sample rates of source and target audio files must match.")
    # zip() below would otherwise silently drop the extra channels.
    if source_waveform.size(0) != target_waveform.size(0):
        raise ValueError("Channel counts of source and target audio files must match.")

    n_fft = 2048
    hop_length = 512
    window = torch.hann_window(n_fft)
    source_stfts = torch.stft(source_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect")
    target_stfts = torch.stft(target_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect")

    # Bin centre frequencies from 0 up to Nyquist; true division keeps the
    # Nyquist value exact for odd sample rates.
    freqs = torch.linspace(0, source_sr / 2, steps=n_fft // 2 + 1)

    modified_stfts = []
    for source_stft, target_stft in zip(source_stfts, target_stfts):
        if not (transfer_magnitude or transfer_phase):
            # Nothing to transfer: keep the target spectrum untouched.
            modified_stfts.append(target_stft.clone())
            continue

        magnitude = torch.abs(source_stft) if transfer_magnitude else torch.abs(target_stft)
        phase = torch.angle(target_stft)
        if transfer_phase:
            phase = frequency_blend_phases(phase, torch.angle(source_stft), freqs, low_cutoff, high_cutoff, scale_factor=scale_factor)
        # Recombine magnitude and phase into a complex spectrum.
        modified_stfts.append(torch.polar(magnitude, phase))

    # The output is the corrected *target*, so it keeps the target's length.
    modified_waveform = torch.istft(
        torch.stack(modified_stfts),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        length=target_waveform.size(1)
    )

    if output_format == 'flac':
        torchaudio.save(output_file, modified_waveform, target_sr, format="flac", bits_per_sample=16)
    else:
        torchaudio.save(output_file, modified_waveform, target_sr)
    print(f"Saved: {output_file}")
    return output_file
def process_phase_fix(source_file, target_file, output_folder, low_cutoff=500, high_cutoff=9000,
                      scale_factor=1.4, output_format='flac'):
    """Run a phase-only fix and report the outcome instead of raising.

    Creates ``output_folder`` if needed, then calls
    ``transfer_magnitude_phase`` with ``transfer_magnitude=False`` and
    ``transfer_phase=True``.

    Args:
        source_file: Path of the audio file supplying the reference phase.
        target_file: Path of the audio file to be corrected.
        output_folder: Directory for the output file (created on demand).
        low_cutoff: Lower blend cutoff, forwarded unchanged.
        high_cutoff: Upper blend cutoff, forwarded unchanged.
        scale_factor: Blend scale factor, forwarded unchanged.
        output_format: Output format, forwarded unchanged.

    Returns:
        ``(output_file, message)`` on success, or ``(None, message)`` with a
        description of the error on failure.
    """
    try:
        # Kept inside the try so that an unusable output folder is reported
        # through the same (None, message) contract as every other failure.
        os.makedirs(output_folder, exist_ok=True)
        output_file = transfer_magnitude_phase(
            source_file=source_file,
            target_file=target_file,
            output_folder=output_folder,
            transfer_magnitude=False,
            transfer_phase=True,
            low_cutoff=low_cutoff,
            high_cutoff=high_cutoff,
            scale_factor=scale_factor,
            output_format=output_format
        )
        return output_file, "Phase fix completed successfully!"
    except Exception as e:
        return None, f"Error during phase fix: {str(e)}"
    finally:
        # Release memory on both the success and the failure path.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Display names of the models listed as phase "source" choices. Every entry
# here is a vocal-oriented model name; how the names are resolved to actual
# checkpoints is decided elsewhere (not visible in this file).
SOURCE_MODELS = [
    "VOCALS-MelBand-Roformer (by Becruily)",
    "VOCALS-Mel-Roformer big beta 4 (by unwa)",
    "VOCALS-Melband-Roformer BigBeta5e (by unwa)",
    "VOCALS-big_beta6 (by Unwa)",
    "VOCALS-big_beta6X (by Unwa)",
    "VOCALS-MelBand-Roformer (by KimberleyJSN)",
    "VOCALS-MelBand-Roformer Kim FT (by Unwa)",
    "VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)",
    "VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)",
    "VOCALS-Mel-Roformer FT 3 Preview (by unwa)",
    "VOCALS-BS-Roformer_1296 (by viperx)",
    "VOCALS-BS-Roformer_1297 (by viperx)",
    "VOCALS-BS-RoformerLargev1 (by unwa)",
    "bs_roformer_revive (by unwa)",
]
# Display names of the models listed as phase-fix "target" choices. Every
# entry here is an instrumental-oriented model name; how the names are
# resolved to actual checkpoints is decided elsewhere (not visible in this
# file).
TARGET_MODELS = [
    "INST-MelBand-Roformer (by Becruily)",
    "INST-Mel-Roformer v1 (by unwa)",
    "INST-Mel-Roformer v2 (by unwa)",
    "inst_v1e (by unwa)",
    "INST-Mel-Roformer v1e+ (by unwa)",
    "Inst_GaboxV7 (by Gabox)",
    "INST-VOC-Mel-Roformer a.k.a. duality (by unwa)",
    "INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)",
    "inst_gabox (by Gabox)",
    "inst_gaboxFlowersV10 (by Gabox)",
]