File size: 10,860 Bytes
dbbd709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Waveformer-main 
@File    :CLAPSep.py
@IDE     :PyCharm 
@Author  :Aisaka/Hao Ma @SDU
@Date    :2024/2/28 下午1:12 
'''

import torch
import laion_clap
from torchmetrics.audio.snr import(
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr)
from torchmetrics.audio.sdr import(
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)
import copy
import loralib as lora
from torchlibrosa import ISTFT, STFT, SpecAugmentation
from torchlibrosa.stft import magphase
import librosa
import pytorch_lightning as pl


def loss_fn(pred, tgt):
    """Return the training loss: a negated 0.9 / 0.1 blend of SNR and SI-SNR.

    Both metrics are averaged over the batch before being combined, and the
    result is negated so that minimising the loss maximises the two ratios.

    Args:
        pred: Estimated waveform(s).
        tgt: Reference waveform(s) the estimate is scored against.

    Returns:
        Scalar tensor ``-0.9 * mean(SNR) - 0.1 * mean(SI-SNR)``.
    """
    mean_snr = snr(pred, tgt).mean()
    mean_si_snr = si_snr(pred, tgt).mean()
    return -0.9 * mean_snr - 0.1 * mean_si_snr


def set_module(model, submodule_key, module):
    """Assign ``module`` to the attribute addressed by a dotted path under ``model``.

    Args:
        model: Root object the dotted path is resolved against.
        submodule_key: Dot-separated attribute path, e.g. ``"layers.0.attn.qkv"``.
            Every segment but the last is followed with ``getattr``; the last
            segment is the attribute that gets (re)assigned.
        module: Object to store at that location.

    Raises:
        AttributeError: If an intermediate segment of the path does not exist.
    """
    *parent_path, leaf_name = submodule_key.split('.')
    owner = model
    for attr_name in parent_path:
        owner = getattr(owner, attr_name)
    setattr(owner, leaf_name, module)


def process_model(model, rank):
    """Replace the linear layers of every WindowAttention module with LoRA linears.

    Each ``torch.nn.Linear`` found inside a module whose type name contains
    ``'WindowAttention'`` is swapped, in place, for a ``lora.Linear`` of the same
    shape. The replacement shares the original weight (and bias, when present)
    parameter objects, so only the newly created low-rank matrices are new.

    Args:
        model: Module tree to modify in place.
        rank: LoRA rank ``r`` passed to every ``lora.Linear``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # Snapshot both walks: set_module() rewrites the tree while we traverse it,
    # and named_modules() is a lazy generator over the live structure.
    for prefix, module in list(model.named_modules()):
        if 'WindowAttention' not in str(type(module)):
            continue
        for local_name, layer in list(module.named_modules()):
            if not isinstance(layer, torch.nn.Linear):
                continue
            # nn.Linear always *has* a ``bias`` attribute (it is None when the
            # layer was built with bias=False), so test the value, not hasattr.
            has_bias = layer.bias is not None
            lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                     bias=has_bias, merge_weights=False)
            lora_layer.weight = layer.weight
            if has_bias:
                lora_layer.bias = layer.bias
            # Skip empty segments so a root-level match does not produce a key
            # with a leading or trailing dot.
            full_key = '.'.join(part for part in (prefix, local_name) if part)
            set_module(model, full_key, lora_layer)
    return model


class LightningModule(pl.LightningModule):
    """Lightning wrapper that trains a mask decoder on top of a CLAP audio encoder.

    The module keeps two encoders:

    * ``clap_model`` -- frozen, used only to embed the positive/negative queries
      (text captions and, during training, audio samples).
    * ``audio_branch`` -- a deep copy of the CLAP audio encoder that processes the
      mixture. Forward hooks capture its intermediate outputs into
      ``self.features``, which the decoder consumes as hidden state and skip
      connections.

    ``self.features`` is a plain list shared with the forward hooks, so it must be
    emptied before and after every encoder pass.
    """

    def __init__(self, clap_model, decoder_model, lr, use_lora=False, rank=8, nfft=1024):
        """Build the module.

        Args:
            clap_model: CLAP wrapper exposing ``model.audio_branch``,
                ``get_text_embedding`` and ``get_audio_embedding_from_data``.
                All of its parameters are frozen here.
            decoder_model: Mask decoder; must expose a boolean-like ``phase``
                attribute and accept ``hidden_state``, ``skip_features`` and
                ``embed`` keyword arguments.
            lr: Learning rate for AdamW.
            use_lora: If true, inject LoRA layers into the copied audio encoder
                and mark only those as trainable.
            rank: LoRA rank, used only when ``use_lora`` is true.
            nfft: FFT size (also used as window length) of the STFT/ISTFT pair.
        """
        super().__init__()
        self.phase = decoder_model.phase
        self.lr = lr
        self.clap_model = clap_model
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # Copied after freezing, so the mixture encoder starts out frozen as well.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
            lora.mark_only_lora_as_trainable(self.audio_branch, bias='lora_only')

        self.decoder_model = decoder_model
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        self.features = self.install_forward_hooks()

    def _collect_encoder_features(self, mag, mixed_resample):
        """Refill ``self.features`` with ``mag`` followed by the encoder hook outputs.

        The list is cleared first so entries left behind by any earlier forward
        pass cannot shift which element the decoder sees as hidden state.
        """
        del self.features[:]
        self.features.append(mag)
        # The return value is unused; the forward hooks append to self.features.
        self.audio_branch({"waveform": mixed_resample})

    def _separate(self, embed_p, embed_n, mag, cos, sin, length):
        """Run the decoder on the captured features and reconstruct a waveform.

        The positive and negative query embeddings are concatenated on the last
        dimension and L2-normalised. Training, validation and test all go through
        this method so the decoder is conditioned identically in every stage.

        Returns:
            The predicted waveform produced by ``wav_reconstruct``.
        """
        embed = torch.nn.functional.normalize(torch.concat([embed_p, embed_n], dim=-1), dim=-1)
        mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
        return self.wav_reconstruct(mask, mag, cos, sin, length=length)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        The batch unpacks to ``(mixed, mixed_resample, pos_cap, neg_cap, gt,
        pos_sample, neg_sample)``. Query embeddings are a random convex blend of
        the audio and text embeddings; then, with probability 0.25 each, only the
        positive or only the negative query is kept (the other is zeroed), and
        otherwise both are used.
        """
        # Keep both encoders in eval mode even though Lightning put the whole
        # module in train mode.
        self.clap_model.eval()
        self.audio_branch.eval()
        mixed, mixed_resample, pos_cap, neg_cap, gt, pos_sample, neg_sample = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            # One blend factor per batch, shared by positive and negative queries.
            blend = torch.rand((1,)).type_as(gt)
            embed_pos_a, embed_neg_a = torch.chunk(
                self.clap_model.get_audio_embedding_from_data(torch.concat([pos_sample, neg_sample], dim=0),
                                                              use_tensor=True), dim=0, chunks=2)
            embed_pos_t, embed_neg_t = torch.chunk(
                self.clap_model.get_text_embedding(pos_cap + neg_cap, use_tensor=True), dim=0, chunks=2)
            embed_pos = blend * embed_pos_a + (1 - blend) * embed_pos_t
            embed_neg = blend * embed_neg_a + (1 - blend) * embed_neg_t
        self._collect_encoder_features(mag, mixed_resample)
        draw = torch.rand((1,))
        if draw < 0.25:
            # Positive query only.
            embed_neg = torch.zeros_like(embed_pos)
        elif draw < 0.5:
            # Negative query only.
            embed_pos = torch.zeros_like(embed_neg)
        loss = self.cal_loss(embed_pos, embed_neg, mag, cos, sin, length=mixed.size(-1), gt=gt)
        self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, sync_dist=True, batch_size=len(mixed))
        del self.features[:]
        return loss

    def cal_loss(self, embed_p, embed_n, mag, cos, sin, length, gt):
        """Separate with the given query embeddings and score against ``gt``.

        Expects ``self.features`` to have been filled for the current mixture.

        Returns:
            The scalar produced by ``loss_fn`` for the prediction and ``gt``.
        """
        pred = self._separate(embed_p, embed_n, mag, cos, sin, length=length)
        return loss_fn(pred, gt)

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        """Apply the predicted mask to the mixture spectrogram and invert to audio.

        Args:
            mask: When ``self.phase`` is truthy, an indexable of three components:
                ``mask[0]`` scales the magnitude and ``mask[1]``/``mask[2]`` are
                the real/imaginary parts of a phase-rotation mask. Otherwise a
                single magnitude mask.
            mag_x: Mixture magnitude.
            cos_x: Cosine of the mixture phase.
            sin_x: Sine of the mixture phase.
            length: Target length passed to the ISTFT.

        Returns:
            The reconstructed waveform.
        """
        # ref: https://github.com/Audio-AGI/AudioSep/blob/main/models/resunet.py
        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        if self.phase:
            mag_y = torch.nn.functional.relu_(mag_x * mask[0])
            _, mask_cos, mask_sin = magphase(mask[1], mask[2])
            cos_y = cos_x * mask_cos - sin_x * mask_sin
            sin_y = sin_x * mask_cos + cos_x * mask_sin
        else:
            # Magnitude-only masking: reuse the mixture phase unchanged.
            mag_y = torch.nn.functional.relu_(mag_x * mask)
            cos_y = cos_x
            sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred

    def validation_step(self, batch, batch_idx):
        """Log the SI-SNR improvement over the unprocessed mixture as ``val_loss``.

        Only text queries are used. Note that despite the name, a larger
        ``val_loss`` is better, which is why the LR scheduler monitors it in
        ``'max'`` mode.
        """
        mixed, mixed_resample, label, neg_label, gt, _, _ = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            embed_pos = self.clap_model.get_text_embedding(label, use_tensor=True)
            embed_neg = self.clap_model.get_text_embedding(neg_label, use_tensor=True)
            self._collect_encoder_features(mag, mixed_resample)
            pred = self._separate(embed_pos, embed_neg, mag, cos, sin, length=mixed.size(-1))
            loss = si_snr(pred, gt).mean() - si_snr(mixed, gt).mean()
        del self.features[:]
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=len(mixed))
        return {"val_loss": loss}

    def on_test_start(self) -> None:
        """Reset the per-sample metric accumulators (kept on the CPU)."""
        self.sdr_vals = torch.tensor([])
        self.sdri_vals = torch.tensor([])
        self.sisdr_vals = torch.tensor([])
        self.sisdri_vals = torch.tensor([])

    def test_step(self, batch, batch_idx):
        """Accumulate SDR, SDRi, SI-SDR and SI-SDRi for one batch.

        The test batch unpacks to ``(mixed, mixed_resample, label, neg_label,
        gt)``. Both the positive and the negative text query are used.
        """
        mixed, mixed_resample, label, neg_label, gt = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            embed_pos_batched, embed_neg_batched = torch.chunk(
                self.clap_model.get_text_embedding(label + neg_label, use_tensor=True), chunks=2, dim=0)
            # To evaluate with a single query, replace the other embedding with
            # torch.zeros_like(...) before calling _separate:
            #   only positive: embed_neg_batched = torch.zeros_like(embed_neg_batched)
            #   only negative: embed_pos_batched = torch.zeros_like(embed_pos_batched)
            self._collect_encoder_features(mag, mixed_resample)
            pred = self._separate(embed_pos_batched, embed_neg_batched, mag, cos, sin, length=mixed.size(-1))
            sisdr = si_sdr(pred, gt).cpu()
            self.sisdr_vals = torch.concat([self.sisdr_vals, sisdr])
            self.sisdri_vals = torch.concat([self.sisdri_vals, sisdr - si_sdr(mixed, gt).cpu()])
            sdr_ = sdr(pred, gt).cpu()
            self.sdr_vals = torch.concat([self.sdr_vals, sdr_])
            self.sdri_vals = torch.concat([self.sdri_vals, sdr_ - sdr(mixed, gt).cpu()])
        del self.features[:]

    def on_test_end(self) -> None:
        """Print mean and standard deviation of every accumulated test metric."""
        report = (
            ("SDR", self.sdr_vals),
            ("SDRi", self.sdri_vals),
            ("SISDR", self.sisdr_vals),
            ("SISDRi", self.sisdri_vals),
        )
        for name, vals in report:
            print(f"{name}-mean: {torch.mean(vals).item():.4f}, {name}-std: {torch.std(vals).item():.4f}")

    def configure_optimizers(self):
        """Return AdamW plus a plateau scheduler driven by ``val_loss``.

        The scheduler runs once per epoch in ``'max'`` mode because ``val_loss``
        is an SI-SNR improvement (higher is better).
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
        # confirm against the pinned torch version before upgrading.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=5,
                                                               verbose=True, min_lr=5e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss"
            },
        }

    def install_forward_hooks(self):
        """Register all forward hooks and return the list they write into.

        Hooks installed:

        * on the frozen CLAP encoder's ``bn0``: SpecAugment applied to its output
          (this is the encoder used for query audio, not the mixture encoder);
        * on the mixture encoder's ``spectrogram_extractor``: pad dimension 2 of
          the output up to 1024;
        * on the mixture encoder's ``patch_embed`` and each entry of ``layers``:
          append the output to the returned list.

        Returns:
            The list that the feature hooks append to; the caller stores it as
            ``self.features``.
        """
        features = []
        # Held only by the closure below, not registered as a submodule, so
        # train()/eval() on this LightningModule does not reach it.
        spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                          freq_drop_width=8, freq_stripes_num=2)

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            # These layers return a tuple/sequence; only the first item is kept.
            features.append(output[0])

        def spec_augmentation_hook(_, __, out):
            # Swap dims 1 and 3 for the augmenter, then swap back.
            out = out.transpose(1, 3)
            out = spec_augmenter(out)
            return out.transpose(1, 3)

        def spectrogram_padding(_, __, out):
            # Pads the end of dimension 2 to 1024 (a negative amount would crop).
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.clap_model.model.audio_branch.bn0.register_forward_hook(spec_augmentation_hook)
        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features

    # # this will only save tuned parameters during training
    # def on_save_checkpoint(self, checkpoint):
    #     weights = checkpoint['state_dict']
    #     new_dict = {}
    #     for k, v in weights.items():
    #         if any(e in k for e in ['lora', 'attn.qkv.bias', 'attn.proj.bias', 'decoder_model']):
    #             new_dict[k] = v
    #     checkpoint['state_dict'] = new_dict