File size: 5,768 Bytes
ddadeb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torchaudio
import os
import gc

def frequency_blend_phases(phase1, phase2, freq_bins, low_cutoff=500, high_cutoff=5000, base_factor=0.25, scale_factor=1.85):
    """Blend two phase arrays with a frequency-dependent weight.

    Row ``i`` (the first axis indexes frequency bins) of the result is
    ``(1 - w[i]) * phase1[i] + w[i] * phase2[i]``, where ``w`` depends on
    ``freq_bins[i]``:

    * at or below ``low_cutoff``: ``w = base_factor``
    * at or above ``high_cutoff``: ``w = base_factor + scale_factor``
    * in between: a linear ramp from the first value to the second.

    With the defaults the top weight is 0.25 + 1.85 = 2.1, i.e. the blend
    extrapolates past ``phase2`` rather than interpolating between the two
    inputs. The angles are combined linearly and only then wrapped back
    into ``[-pi, pi)``.

    Args:
        phase1: Phase tensor whose first dimension indexes frequency bins.
        phase2: Phase tensor with exactly the same shape as ``phase1``.
        freq_bins: 1-D tensor/sequence with the frequency of each row of
            ``phase1`` (same units as the cutoffs). Integer values are
            promoted to floating point.
        low_cutoff: Frequency where the ramp starts.
        high_cutoff: Frequency where the ramp ends; must exceed ``low_cutoff``.
        base_factor: Weight of ``phase2`` at and below ``low_cutoff``.
        scale_factor: Extra weight of ``phase2`` reached at ``high_cutoff``.

    Returns:
        Tensor shaped like ``phase1`` holding the blended, wrapped phase.

    Raises:
        ValueError: if the phase shapes differ, ``freq_bins`` has the wrong
            length, or ``low_cutoff >= high_cutoff``.
    """
    if phase1.shape != phase2.shape:
        raise ValueError("phase1 and phase2 must have the same shape.")
    if len(freq_bins) != phase1.shape[0]:
        raise ValueError("freq_bins must have the same length as the number of frequency bins in phase1 and phase2.")
    if low_cutoff >= high_cutoff:
        raise ValueError("low_cutoff must be less than high_cutoff.")

    freqs = torch.as_tensor(freq_bins)
    # Weights must be computed in floating point: an integer tensor here
    # (e.g. from torch.arange) would truncate the fractional ramp values.
    if not torch.is_floating_point(freqs):
        freqs = freqs.to(torch.get_default_dtype())

    # Position inside the [low_cutoff, high_cutoff] band, clamped so bins
    # outside the band get the flat base / top weight.
    ramp = ((freqs - low_cutoff) / (high_cutoff - low_cutoff)).clamp(0.0, 1.0)
    weight2 = base_factor + scale_factor * ramp
    weight1 = 1 - weight2

    # Reshape the per-bin weights to (F, 1, ..., 1) so they broadcast over
    # every trailing axis of the phase tensors (replaces a per-row loop).
    trailing = (1,) * (phase1.dim() - 1)
    w1 = weight1.to(device=phase1.device, dtype=phase1.dtype).reshape(-1, *trailing)
    w2 = weight2.to(device=phase1.device, dtype=phase1.dtype).reshape(-1, *trailing)
    blended_phase = w1 * phase1 + w2 * phase2

    # Wrap back into [-pi, pi): the linear combination (especially with
    # weights > 1) can leave the principal range.
    return torch.remainder(blended_phase + torch.pi, 2 * torch.pi) - torch.pi

def transfer_magnitude_phase(source_file, target_file, output_folder, transfer_magnitude=False, transfer_phase=True,
                              low_cutoff=500, high_cutoff=9000, scale_factor=1.4, output_format='flac'):
    """Rebuild ``target_file`` using magnitude and/or phase taken from ``source_file``.

    Both files are loaded with torchaudio and transformed with an STFT
    (n_fft=2048, hop=512, Hann window). For each channel:

    * if ``transfer_magnitude`` is true, the magnitude is replaced by the
      source's magnitude while the target's phase is kept;
    * if ``transfer_phase`` is true, the phase is replaced by
      ``frequency_blend_phases(target_phase, source_phase, ...)`` using
      ``low_cutoff``, ``high_cutoff`` and ``scale_factor``, keeping the
      current magnitude.

    The result is inverse-transformed to the *target's* length and written
    to ``output_folder`` as ``"<name> (Fixed Instrumental).flac"`` (16-bit)
    when ``output_format == 'flac'``, otherwise as ``.wav``. ``<name>`` is the
    target's base name with ``_other``/``_vocals``/``_instrumental`` (lower-
    or title-case) removed wherever they occur.

    A mono source is reused for every channel of a multi-channel target. If
    the source has more channels than the target, only its first channels
    are used.

    Returns:
        Path of the written audio file.

    Raises:
        ValueError: if the sample rates differ, or the target has more
            channels than a non-mono source.
    """
    target_name = os.path.splitext(os.path.basename(target_file))[0]
    # Drop stem-type tags (same order as before) so the output is named
    # after the track rather than the separated stem.
    for tag in ("_other", "_vocals", "_instrumental", "_Other", "_Vocals", "_Instrumental"):
        target_name = target_name.replace(tag, "")
    target_name = target_name.strip()

    ext = '.flac' if output_format == 'flac' else '.wav'
    output_file = os.path.join(output_folder, f"{target_name} (Fixed Instrumental){ext}")

    print(f"Phase Fixing: {os.path.basename(target_file)}...")
    source_waveform, source_sr = torchaudio.load(source_file)
    target_waveform, target_sr = torchaudio.load(target_file)

    if source_sr != target_sr:
        raise ValueError("Sample rates of source and target audio files must match.")

    source_channels = source_waveform.size(0)
    target_channels = target_waveform.size(0)
    if source_channels == 1 and target_channels > 1:
        # Reuse the single source channel everywhere; the channel zip below
        # would otherwise silently drop all target channels but the first.
        source_waveform = source_waveform.repeat(target_channels, 1)
    elif source_channels < target_channels:
        raise ValueError("Source audio must be mono or have at least as many channels as the target audio.")

    n_fft = 2048
    hop_length = 512
    window = torch.hann_window(n_fft)

    source_stfts = torch.stft(source_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect")
    target_stfts = torch.stft(target_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect")

    # Frequency in Hz of each of the n_fft // 2 + 1 one-sided STFT bins.
    freqs = torch.linspace(0, source_sr // 2, steps=n_fft // 2 + 1)

    modified_stfts = []
    for source_stft, target_stft in zip(source_stfts, target_stfts):
        # No in-place ops follow, so the target spectrum needs no copy.
        modified_stft = target_stft
        if transfer_magnitude:
            modified_stft = torch.abs(source_stft) * torch.exp(1j * torch.angle(modified_stft))

        if transfer_phase:
            blended_phase = frequency_blend_phases(
                torch.angle(target_stft), torch.angle(source_stft), freqs,
                low_cutoff, high_cutoff, scale_factor=scale_factor,
            )
            modified_stft = torch.abs(modified_stft) * torch.exp(1j * blended_phase)

        modified_stfts.append(modified_stft)

    modified_waveform = torch.istft(
        torch.stack(modified_stfts),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        # The output replaces the target, so it keeps the target's length.
        length=target_waveform.size(-1),
    )

    if output_format == 'flac':
        torchaudio.save(output_file, modified_waveform, target_sr, format="flac", bits_per_sample=16)
    else:
        torchaudio.save(output_file, modified_waveform, target_sr)

    print(f"Saved: {output_file}")
    return output_file

def process_phase_fix(source_file, target_file, output_folder, low_cutoff=500, high_cutoff=9000,
                      scale_factor=1.4, output_format='flac'):
    """Phase-fix ``target_file`` using ``source_file`` and report the outcome.

    Creates ``output_folder`` if needed, then calls transfer_magnitude_phase
    with ``transfer_magnitude=False`` and ``transfer_phase=True``. Every
    exception is caught, including a failure to create the output folder,
    so callers always receive a ``(path, message)`` pair. Python garbage
    collection and, when CUDA is available, ``torch.cuda.empty_cache()`` run
    after every attempt, whether it succeeded or failed.

    Returns:
        ``(output_file, "Phase fix completed successfully!")`` on success, or
        ``(None, "Error during phase fix: <error>")`` on failure.
    """
    try:
        os.makedirs(output_folder, exist_ok=True)
        output_file = transfer_magnitude_phase(
            source_file=source_file,
            target_file=target_file,
            output_folder=output_folder,
            transfer_magnitude=False,
            transfer_phase=True,
            low_cutoff=low_cutoff,
            high_cutoff=high_cutoff,
            scale_factor=scale_factor,
            output_format=output_format
        )
    except Exception as e:
        # UI-facing boundary: report the error instead of raising.
        return None, f"Error during phase fix: {e}"
    finally:
        # Free memory from this run even when it failed part-way.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_file, "Phase fix completed successfully!"

# Display names of separation models, mostly prefixed 'VOCALS-'.
# NOTE(review): not referenced in the visible part of this file; presumably
# these are the choices offered for the phase *source* audio (the
# ``source_file`` argument of process_phase_fix) — confirm with the consumer.
SOURCE_MODELS = [
    'VOCALS-MelBand-Roformer (by Becruily)',
    'VOCALS-Mel-Roformer big beta 4 (by unwa)',
    'VOCALS-Melband-Roformer BigBeta5e (by unwa)',
    'VOCALS-big_beta6 (by Unwa)',
    'VOCALS-big_beta6X (by Unwa)',
    'VOCALS-MelBand-Roformer (by KimberleyJSN)',
    'VOCALS-MelBand-Roformer Kim FT (by Unwa)',
    'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)',
    'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)',
    'VOCALS-Mel-Roformer FT 3 Preview (by unwa)',
    'VOCALS-BS-Roformer_1296 (by viperx)',
    'VOCALS-BS-Roformer_1297 (by viperx)',
    'VOCALS-BS-RoformerLargev1 (by unwa)',
    'bs_roformer_revive (by unwa)'
]

# Display names of separation models, mostly prefixed 'INST-' / 'inst_'.
# NOTE(review): not referenced in the visible part of this file; presumably
# these are the choices for the *target* audio whose phase gets fixed (the
# ``target_file`` of process_phase_fix, saved as "(Fixed Instrumental)")
# — confirm with the consumer.
TARGET_MODELS = [
    'INST-MelBand-Roformer (by Becruily)',
    'INST-Mel-Roformer v1 (by unwa)',
    'INST-Mel-Roformer v2 (by unwa)',
    'inst_v1e (by unwa)',
    'INST-Mel-Roformer v1e+ (by unwa)',
    'Inst_GaboxV7 (by Gabox)',
    'INST-VOC-Mel-Roformer a.k.a. duality (by unwa)',
    'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)',
    'inst_gabox (by Gabox)',
    'inst_gaboxFlowersV10 (by Gabox)'
]